diff --git a/tuning/aim_loader.py b/tuning/aim_loader.py
deleted file mode 100644
index b1a86f5f5..000000000
--- a/tuning/aim_loader.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright The IBM Tuning Team
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Standard
-import os
-
-# Third Party
-from aim.hugging_face import AimCallback
-
-
-def get_aimstack_callback():
-    # Initialize a new run
-    aim_server = os.environ.get("AIMSTACK_SERVER")
-    aim_db = os.environ.get("AIMSTACK_DB")
-    aim_experiment = os.environ.get("AIMSTACK_EXPERIMENT")
-    if aim_experiment is None:
-        aim_experiment = ""
-
-    if aim_server:
-        aim_callback = AimCallback(
-            repo="aim://" + aim_server + "/", experiment=aim_experiment
-        )
-    if aim_db:
-        aim_callback = AimCallback(repo=aim_db, experiment=aim_experiment)
-    else:
-        aim_callback = AimCallback(experiment=aim_experiment)
-
-    return aim_callback
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index 00a536568..8bd81865b 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -72,3 +72,11 @@ class TrainingArguments(transformers.TrainingArguments):
         default=False,
         metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
     )
+    tracker: str.lower = field(
+        default=None,
+        metadata={
+            "help": "Experiment tracker to use.\n" + \
+            "Available trackers are: aim, none.\n" + \
+            "Requires additional configs; see tuning/config/tracker_configs.py"
+        },
+    )
diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py
new file mode 100644
index 000000000..5ce243458
--- /dev/null
+++ b/tuning/config/tracker_configs.py
@@ -0,0 +1,31 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from dataclasses import dataclass
+
+
+@dataclass
+class AimConfig:
+    # Name of the experiment
+    experiment: str = None
+    # 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim')
+    # or a remote repository hosted on a server.
+    # When 'aim_remote_server_ip' and 'aim_remote_server_port' are both set,
+    # they designate a remote Aim repo; otherwise 'aim_repo' specifies the
+    # local directory, defaulting to '.aim'.
+    # See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html
+    # for documentation on Aim remote server tracking.
+    aim_repo: str = ".aim"
+    aim_remote_server_ip: str = None
+    aim_remote_server_port: int = None
+    # Location where run_hash is exported; if unspecified, this is written to
+    # training_args.output_dir/.aim_run_hash when output_dir is set, else it is not exported.
+    aim_run_hash_export_path: str = None
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 29c5fd299..be5f35a4a 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -14,9 +14,10 @@
 # Standard
 from datetime import datetime
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
 import json
 import os
 import sys
+import time
 
 # Third Party
@@ -37,12 +38,15 @@ import transformers
 
 # Local
-from tuning.aim_loader import get_aimstack_callback
-from tuning.config import configs, peft_config
+from tuning.config import configs, peft_config, tracker_configs
 from tuning.data import tokenizer_data_utils
+from tuning.trackers.tracker import Tracker
+from tuning.trackers.tracker_factory import get_tracker
 from tuning.utils.config_utils import get_hf_peft_config
 from tuning.utils.data_type_utils import get_torch_dtype
 
+logger = logging.get_logger("sft_trainer")
+
 class FileLoggingCallback(TrainerCallback):
     """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
@@ -95,6 +99,9 @@ def train(
     peft_config: Optional[  # pylint: disable=redefined-outer-name
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
+    callbacks: Optional[List[TrainerCallback]] = None,
+    tracker: Optional[Tracker] = None,
+    exp_metadata: Optional[Dict] = None,
 ):
     """Call the SFTTrainer
 
@@ -106,10 +113,13 @@
             peft_config.PromptTuningConfig for prompt tuning | \
             None for fine tuning
             The peft configuration to pass to trainer
+        callbacks: List of callbacks to attach to the SFTTrainer.
+        tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS.
+            Initialized using tuning.trackers.tracker_factory.get_tracker
+            with configs from tuning.config.tracker_configs.
+        exp_metadata: Dict of key:value pairs passed to train to be recorded by the tracker.
     """
 
-    logger = logging.get_logger("sft_trainer")
-
     # Validate parameters
     if (not isinstance(train_args.num_train_epochs, float)) or (
         train_args.num_train_epochs <= 0
@@ -121,12 +131,16 @@
         raise ValueError("gradient_accumulation_steps has to be an integer >= 1")
 
     task_type = "CAUSAL_LM"
+    additional_metrics = {}
+
+    model_load_time = time.time()
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
+    additional_metrics["model_load_time"] = time.time() - model_load_time
 
     peft_config = get_hf_peft_config(task_type, peft_config)
 
@@ -218,10 +232,6 @@
             "Validation dataset length is %s", len(formatted_validation_dataset)
         )
 
-    aim_callback = get_aimstack_callback()
-    file_logger_callback = FileLoggingCallback(logger)
-    callbacks = [aim_callback, file_logger_callback]
-
     if train_args.packing:
         logger.info("Packing is set to True")
         data_collator = None
@@ -261,6 +271,16 @@
         peft_config=peft_config,
     )
 
+    # We track additional metrics and experiment metadata after
+    # Trainer object creation to ensure that this is not repeated
+    # multiple times for FSDP runs.
+    if tracker is not None:
+        # Currently tracked only on process zero.
+        if trainer.is_world_process_zero():
+            for k, v in additional_metrics.items():
+                tracker.track(metric=v, name=k, stage="additional_metrics")
+            tracker.set_params(params=exp_metadata, name="experiment_metadata")
+
     if trainer.is_fsdp_enabled and peft_config is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
@@ -276,6 +296,7 @@ def main(**kwargs):  # pylint: disable=unused-argument
             configs.TrainingArguments,
             peft_config.LoraConfig,
             peft_config.PromptTuningConfig,
+            tracker_configs.AimConfig,
         )
     )
     parser.add_argument(
@@ -284,22 +305,69 @@ def main(**kwargs):  # pylint: disable=unused-argument
         choices=["pt", "lora", None, "none"],
         default="pt",
     )
+    parser.add_argument(
+        "--exp_metadata",
+        type=str,
+        default=None,
+        help='Pass a JSON string of key:value pairs to be associated with the tuning run in the tracker, e.g. \'{"gpu":"A100-80G"}\'.',
+    )
     (
         model_args,
         data_args,
         training_args,
         lora_config,
         prompt_tuning_config,
-        peft_method,
+        aim_config,
+        additional,
         _,
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-    if peft_method.peft_method == "lora":
+
+    peft_method = additional.peft_method
+    if peft_method == "lora":
         tune_config = lora_config
-    elif peft_method.peft_method == "pt":
+    elif peft_method == "pt":
         tune_config = prompt_tuning_config
     else:
         tune_config = None
-    train(model_args, data_args, training_args, tune_config)
+
+    tracker_name = training_args.tracker
+    if tracker_name == "aim":
+        tracker_config = aim_config
+    else:
+        tracker_config = None
+
+    # Initialize callbacks
+    file_logger_callback = FileLoggingCallback(logger)
+    callbacks = [file_logger_callback]
+
+    # Initialize the tracker
+    tracker = get_tracker(tracker_name, tracker_config)
+    tracker_callback = tracker.get_hf_callback()
+    if tracker_callback is not None:
+        callbacks.append(tracker_callback)
+
+    # Extra metadata passed via the client
+    metadata = None
+    if additional.exp_metadata is not None:
+        try:
+            metadata = json.loads(additional.exp_metadata)
+            if metadata is None or not isinstance(metadata, dict):
+                logger.warning(
+                    "metadata cannot be converted to a simple key:value dict; ignoring it"
+                )
+                metadata = None
+        except json.JSONDecodeError:
+            logger.error("failed while parsing extra metadata; pass a valid JSON string")
+
+    train(
+        model_args=model_args,
+        data_args=data_args,
+        train_args=training_args,
+        peft_config=tune_config,
+        callbacks=callbacks,
+        tracker=tracker,
+        exp_metadata=metadata,
+    )
 
 
 if __name__ == "__main__":
diff --git a/tuning/trackers/__init__.py b/tuning/trackers/__init__.py
new file mode 100644
index 000000000..55b377746
--- /dev/null
+++ b/tuning/trackers/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py
new file mode 100644
index 000000000..21e66e5fe
--- /dev/null
+++ b/tuning/trackers/aimstack_tracker.py
@@ -0,0 +1,95 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import os
+
+# Third Party
+from aim.hugging_face import AimCallback
+
+# Local
+from .tracker import Tracker
+from tuning.config.tracker_configs import AimConfig
+
+
+class CustomAimCallback(AimCallback):
+
+    # A path to export the run hash generated by Aim.
+    # This is used to link back to the experiments from outside Aimstack.
+    hash_export_path = None
+
+    def on_init_end(self, args, state, control, **kwargs):
+
+        if state and not state.is_world_process_zero:
+            return
+
+        self.setup()  # initializes the run_hash
+
+        # Store the run hash.
+        # Default the export path to the output directory when none is given.
+        if self.hash_export_path is None:
+            if args and args.output_dir:
+                # args.output_dir/.aim_run_hash
+                self.hash_export_path = os.path.join(
+                    args.output_dir, ".aim_run_hash"
+                )
+
+        if self.hash_export_path:
+            with open(self.hash_export_path, "w") as f:
+                run_hash = self.experiment.hash
+                f.write('{"run_hash":"' + str(run_hash) + '"}\n')
+
+    def on_train_begin(self, args, state, control, model=None, **kwargs):
+        # Call setup directly to make sure hyperparameters and model info are recorded.
+        self.setup(args=args, state=state, model=model)
+
+
+class AimStackTracker(Tracker):
+    def __init__(self, tracker_config: AimConfig):
+        super().__init__(name="aim", tracker_config=tracker_config)
+
+    def get_hf_callback(self):
+        c = self.config
+        exp = c.experiment
+        ip = c.aim_remote_server_ip
+        port = c.aim_remote_server_port
+        repo = c.aim_repo
+        hash_export_path = c.aim_run_hash_export_path
+
+        # A remote server takes precedence over a local repo; note that 'port'
+        # is an int and must be stringified before URL concatenation.
+        if ip is not None and port is not None:
+            aim_callback = CustomAimCallback(
+                repo="aim://" + ip + ":" + str(port) + "/", experiment=exp
+            )
+        elif repo:
+            aim_callback = CustomAimCallback(repo=repo, experiment=exp)
+        else:
+            aim_callback = CustomAimCallback(experiment=exp)
+
+        aim_callback.hash_export_path = hash_export_path
+        self.hf_callback = aim_callback
+        return self.hf_callback
+
+    def track(self, metric, name, stage="additional_metrics"):
+        context = {"subset": stage}
+        run = self.hf_callback.experiment
+        if run is not None:
+            run.track(metric, name=name, context=context)
+
+    def set_params(self, params, name="extra_params"):
+        if params is None:
+            return
+        run = self.hf_callback.experiment
+        if run is not None:
+            for key, value in params.items():
+                run.set((name, key), value, strict=False)
diff --git a/tuning/trackers/tracker.py b/tuning/trackers/tracker.py
new file mode 100644
index 000000000..162d793c9
--- /dev/null
+++ b/tuning/trackers/tracker.py
@@ -0,0 +1,35 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Generic Tracker API
+class Tracker:
+    def __init__(self, name=None, tracker_config=None) -> None:
+        self.config = tracker_config
+        if name is None:
+            self._name = "none"
+        else:
+            self._name = name
+
+    # Returns a Hugging Face TrainerCallback for this tracker, if any.
+    def get_hf_callback(self):
+        return None
+
+    def track(self, metric, name, stage):
+        pass
+
+    # 'params' is expected to be a key:value mapping of
+    # parameters to be associated with a run.
+    def set_params(self, params, name):
+        pass
diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py
new file mode 100644
index 000000000..917d1741d
--- /dev/null
+++ b/tuning/trackers/tracker_factory.py
@@ -0,0 +1,27 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .aimstack_tracker import AimStackTracker
+from .tracker import Tracker
+
+REGISTERED_TRACKERS = {"aim": AimStackTracker}
+
+
+def get_tracker(tracker_name, tracker_config):
+    # Fall back to the no-op base Tracker for unknown or unset tracker names.
+    if tracker_name in REGISTERED_TRACKERS:
+        tracker_class = REGISTERED_TRACKERS[tracker_name]
+        return tracker_class(tracker_config)
+    return Tracker()
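
A rough usage sketch of how the pieces in this diff compose from client code, for reviewers who want the end-to-end picture. Illustrative only: ModelArguments/DataArguments and the field names used to construct them come from the existing tuning.config.configs module, not from this diff, and the model/data paths are placeholders.

# Hypothetical usage sketch, not part of this change.
from tuning import sft_trainer
from tuning.config import configs
from tuning.config.tracker_configs import AimConfig
from tuning.trackers.tracker_factory import get_tracker

# Assumed placeholders: these argument classes and fields are defined
# elsewhere in the repo, outside this diff.
model_args = configs.ModelArguments(model_name_or_path="facebook/opt-125m")
data_args = configs.DataArguments(training_data_path="train.jsonl")
train_args = configs.TrainingArguments(
    output_dir="/tmp/tuning-output",  # run hash exported to <output_dir>/.aim_run_hash
    num_train_epochs=1.0,
    tracker="aim",  # new field added by this diff
)

# get_tracker returns the no-op base Tracker for unknown or unset names.
tracker = get_tracker("aim", AimConfig(experiment="demo", aim_repo="/tmp/.aim"))

callbacks = [sft_trainer.FileLoggingCallback(sft_trainer.logger)]
hf_callback = tracker.get_hf_callback()
if hf_callback is not None:
    callbacks.append(hf_callback)

sft_trainer.train(
    model_args=model_args,
    data_args=data_args,
    train_args=train_args,
    peft_config=None,  # full fine tuning
    callbacks=callbacks,
    tracker=tracker,
    exp_metadata={"gpu": "A100-80G"},
)

This mirrors what main() now does with --tracker aim and --exp_metadata '{"gpu":"A100-80G"}' on the command line, where HfArgumentParser derives the --aim_repo, --aim_remote_server_ip, --aim_remote_server_port, and --aim_run_hash_export_path flags from AimConfig.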