feat: change tracker API to initialize tracker early and track additional metrics. #50
Changes from all commits
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| # Copyright The IBM Tuning Team | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Standard | ||
| from dataclasses import dataclass | ||
|
|
||
| @dataclass | ||
| class AimConfig: | ||
| # Name of the experiment | ||
| experiment: str = None | ||
| # 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server. | ||
| # When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo. | ||
| # Otherwise, 'aim_repo' specifies the directory, defaulting to '.aim'. | ||
| # See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking. | ||
| aim_repo: str = ".aim" | ||
| aim_remote_server_ip: str = None | ||
| aim_remote_server_port: int = None | ||
| # Path where the run_hash is exported. If unspecified, it is written to | ||
| # training_args.output_dir/.aim_run_hash when output_dir is set; otherwise it is not exported. | ||
| aim_run_hash_export_path: str = None | ||
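For reference, a short sketch of how these fields are expected to combine, based on the callback logic later in this diff (the values here are purely illustrative):

```python
# Hypothetical usage sketch; field names come from AimConfig above,
# values are made up for illustration.
from tuning.config.tracker_configs import AimConfig

# Local tracking (default): runs are stored in the ".aim" directory.
local_cfg = AimConfig(experiment="llama-ft")

# Remote tracking: when both ip and port are set, the Aim callback is
# expected to build an "aim://<ip>:<port>/" repo URL instead of aim_repo.
remote_cfg = AimConfig(
    experiment="llama-ft",
    aim_remote_server_ip="10.0.0.5",
    aim_remote_server_port=53800,
)
```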
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,9 +14,10 @@ | |
|
|
||
| # Standard | ||
| from datetime import datetime | ||
| from typing import Optional, Union | ||
| from typing import Dict, List, Optional, Union | ||
| import json | ||
| import os | ||
| import time | ||
| import sys | ||
|
|
||
| # Third Party | ||
|
|
@@ -37,12 +38,15 @@ | |
| import transformers | ||
|
|
||
| # Local | ||
| from tuning.aim_loader import get_aimstack_callback | ||
| from tuning.config import configs, peft_config | ||
| from tuning.config import configs, peft_config, tracker_configs | ||
| from tuning.data import tokenizer_data_utils | ||
| from tuning.trackers.tracker import Tracker | ||
| from tuning.trackers.tracker_factory import get_tracker | ||
| from tuning.utils.config_utils import get_hf_peft_config | ||
| from tuning.utils.data_type_utils import get_torch_dtype | ||
|
|
||
| logger = logging.get_logger("sft_trainer") | ||
|
|
||
|
|
||
| class FileLoggingCallback(TrainerCallback): | ||
| """Exports metrics, e.g., training loss to a file in the checkpoint directory.""" | ||
|
|
@@ -95,6 +99,9 @@ def train( | |
| peft_config: Optional[ # pylint: disable=redefined-outer-name | ||
| Union[peft_config.LoraConfig, peft_config.PromptTuningConfig] | ||
| ] = None, | ||
| callbacks: Optional[List[TrainerCallback]] = None, | ||
| tracker: Optional[Tracker] = None, | ||
|
Collaborator
Since the
Collaborator
Using a yaml as config was the first thought that came up too. But given HF has chosen to use arguments for most configuration, we went with it as a design choice. If in fms-hf-tuning we choose to use config files for configs, we can do that too.
Collaborator Author
The train function expects to take callbacks which need to be associated with the training, hence main takes the config and initializes the tracker (which is needed for the callback).
||
| exp_metadata: Optional[Dict] = None, | ||
| ): | ||
| """Call the SFTTrainer | ||
|
|
||
|
|
@@ -106,10 +113,13 @@ def train( | |
| peft_config.PromptTuningConfig for prompt tuning | \ | ||
| None for fine tuning | ||
| The peft configuration to pass to trainer | ||
| callbacks: List of callbacks to attach with SFTTrainer. | ||
| tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS | ||
| Initialized using tuning.trackers.tracker_factory.get_tracker | ||
| Using configs in tuning.config.tracker_configs | ||
| exp_metadata: Dict of key value pairs passed to train to be recorded by the tracker. | ||
| """ | ||
|
|
||
| logger = logging.get_logger("sft_trainer") | ||
|
|
||
| # Validate parameters | ||
| if (not isinstance(train_args.num_train_epochs, float)) or ( | ||
| train_args.num_train_epochs <= 0 | ||
|
|
@@ -121,12 +131,16 @@ def train( | |
| raise ValueError("gradient_accumulation_steps has to be an integer >= 1") | ||
|
|
||
| task_type = "CAUSAL_LM" | ||
| additional_metrics = {} | ||
|
|
||
| model_load_time = time.time() | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_args.model_name_or_path, | ||
| cache_dir=train_args.cache_dir, | ||
| torch_dtype=get_torch_dtype(model_args.torch_dtype), | ||
| use_flash_attention_2=model_args.use_flash_attn, | ||
| ) | ||
| additional_metrics["model_load_time"] = time.time() - model_load_time | ||
|
|
||
| peft_config = get_hf_peft_config(task_type, peft_config) | ||
|
|
||
|
|
@@ -218,10 +232,6 @@ def train( | |
| "Validation dataset length is %s", len(formatted_validation_dataset) | ||
| ) | ||
|
|
||
| aim_callback = get_aimstack_callback() | ||
| file_logger_callback = FileLoggingCallback(logger) | ||
| callbacks = [aim_callback, file_logger_callback] | ||
|
|
||
| if train_args.packing: | ||
| logger.info("Packing is set to True") | ||
| data_collator = None | ||
|
|
@@ -261,6 +271,16 @@ def train( | |
| peft_config=peft_config, | ||
| ) | ||
|
|
||
| # We track additional metrics and experiment metadata after | ||
| # Trainer object creation to ensure that this is not repeated | ||
| # multiple times for FSDP runs. | ||
| if tracker is not None: | ||
| # Currently tracked only on process zero. | ||
| if trainer.is_world_process_zero(): | ||
| for k, v in additional_metrics.items(): | ||
| tracker.track(metric=v, name=k, stage="additional_metrics") | ||
| if exp_metadata is not None: | ||
| tracker.set_params(params=exp_metadata, name="experiment_metadata") | ||
|
|
||
| if trainer.is_fsdp_enabled and peft_config is not None: | ||
| trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy( | ||
| model | ||
|
|
@@ -276,6 +296,7 @@ def main(**kwargs): # pylint: disable=unused-argument | |
| configs.TrainingArguments, | ||
| peft_config.LoraConfig, | ||
| peft_config.PromptTuningConfig, | ||
| tracker_configs.AimConfig, | ||
| ) | ||
| ) | ||
| parser.add_argument( | ||
|
|
@@ -284,22 +305,69 @@ def main(**kwargs): # pylint: disable=unused-argument | |
| choices=["pt", "lora", None, "none"], | ||
| default="pt", | ||
| ) | ||
| parser.add_argument( | ||
| "--exp_metadata", | ||
| type=str, | ||
| default=None, | ||
| help='Pass a json string representing K:V pairs to be associated to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'', | ||
| ) | ||
| ( | ||
| model_args, | ||
| data_args, | ||
| training_args, | ||
| lora_config, | ||
| prompt_tuning_config, | ||
| peft_method, | ||
| aim_config, | ||
| additional, | ||
| _, | ||
| ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) | ||
| if peft_method.peft_method == "lora": | ||
|
|
||
| peft_method = additional.peft_method | ||
| if peft_method == "lora": | ||
| tune_config = lora_config | ||
| elif peft_method.peft_method == "pt": | ||
| elif peft_method == "pt": | ||
| tune_config = prompt_tuning_config | ||
| else: | ||
| tune_config = None | ||
| train(model_args, data_args, training_args, tune_config) | ||
|
|
||
| tracker_name = training_args.tracker | ||
| if tracker_name == "aim": | ||
| tracker_config = aim_config | ||
| else: | ||
| tracker_config = None | ||
|
|
||
| # Initialize callbacks | ||
| file_logger_callback = FileLoggingCallback(logger) | ||
| callbacks = [file_logger_callback] | ||
|
|
||
| # Initialize the tracker | ||
| tracker = get_tracker(tracker_name, tracker_config) | ||
|
Collaborator
why not just have a
Collaborator
Agreed. Here too it's a design choice we have to make. We have tried to keep it consistent with how config is managed in fms-hf-tuning. This follows the pattern used in the peft config block above, but if we choose a different approach we can do that.
Collaborator Author
It's a simple choice between performing things explicitly or running the code in some other function that reverts to no callback under the hood. It does not affect the functionality. If you feel strongly that a meta function would help, we can make that change.
||
| tracker_callback = tracker.get_hf_callback() | ||
| if tracker_callback is not None: | ||
| callbacks.append(tracker_callback) | ||
|
|
||
| # extra metadata passed via client | ||
| metadata = None | ||
| if additional.exp_metadata is not None: | ||
| try: | ||
| metadata = json.loads(additional.exp_metadata) | ||
| if metadata is None or not isinstance(metadata, Dict): | ||
| logger.warning( | ||
| "metadata cannot be converted to simple k:v dict ignoring" | ||
| ) | ||
| metadata = None | ||
| except ValueError: | ||
| logger.error("failed while parsing extra metadata; pass a valid json") | ||
|
|
||
| train( | ||
| model_args=model_args, | ||
| data_args=data_args, | ||
| train_args=training_args, | ||
| peft_config=tune_config, | ||
| callbacks=callbacks, | ||
| tracker=tracker, | ||
| exp_metadata=metadata, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
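The net effect of this diff is that the tracker is constructed before `train()` so its callback can be attached up front. A library caller could reproduce the wiring `main()` does above; a minimal sketch using only names introduced in this diff (model/data/training argument handling is prepared as usual and omitted here):

```python
# Sketch: wire up a tracker and callbacks the way main() does above.
from transformers.utils import logging

from tuning.config.tracker_configs import AimConfig
from tuning.sft_trainer import FileLoggingCallback
from tuning.trackers.tracker_factory import get_tracker

logger = logging.get_logger("sft_trainer")

# Initialize the tracker early so its HF callback exists before training.
tracker = get_tracker("aim", AimConfig(experiment="my-exp"))

callbacks = [FileLoggingCallback(logger)]
tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
    callbacks.append(tracker_callback)

# `callbacks`, `tracker`, and an optional exp_metadata dict (e.g.
# {"gpu": "A100-80G"}) are then passed to train(...), as in main().
```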
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright The IBM Tuning Team | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| # Copyright The IBM Tuning Team | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Standard | ||
| import os | ||
|
|
||
| # Third Party | ||
| from aim.hugging_face import AimCallback | ||
|
|
||
| # Local | ||
| from .tracker import Tracker | ||
| from tuning.config.tracker_configs import AimConfig | ||
|
|
||
| class CustomAimCallback(AimCallback): | ||
|
|
||
| # A path to export run hash generated by Aim | ||
| # This is used to link back to the experiments from outside aimstack | ||
| hash_export_path = None | ||
|
|
||
| def on_init_end(self, args, state, control, **kwargs): | ||
|
|
||
| if state and not state.is_world_process_zero: | ||
| return | ||
|
|
||
| self.setup() # initializes the run_hash | ||
|
|
||
| # Store the run hash | ||
| # Change default run hash path to output directory | ||
| if self.hash_export_path is None: | ||
| if args and args.output_dir: | ||
| # args.output_dir/.aim_run_hash | ||
| self.hash_export_path = os.path.join( | ||
| args.output_dir, ".aim_run_hash" | ||
| ) | ||
|
|
||
| if self.hash_export_path: | ||
| with open(self.hash_export_path, "w") as f: | ||
| run_hash = self.experiment.hash | ||
| f.write('{"run_hash":"' + str(run_hash) + '"}\n') | ||
|
|
||
| def on_train_begin(self, args, state, control, model=None, **kwargs): | ||
| # call directly to make sure hyperparameters and model info are recorded. | ||
| self.setup(args=args, state=state, model=model) | ||
|
|
||
| class AimStackTracker(Tracker): | ||
| def __init__(self, tracker_config: AimConfig): | ||
| super().__init__(name="aim", tracker_config=tracker_config) | ||
|
|
||
| def get_hf_callback(self): | ||
| c = self.config | ||
| exp = c.experiment | ||
| ip = c.aim_remote_server_ip | ||
| port = c.aim_remote_server_port | ||
| repo = c.aim_repo | ||
| hash_export_path = c.aim_run_hash_export_path | ||
|
|
||
| # A remote server takes precedence; otherwise fall back to the local repo. | ||
| if ip is not None and port is not None: | ||
| aim_callback = CustomAimCallback( | ||
| repo="aim://" + ip + ":" + str(port) + "/", experiment=exp | ||
| ) | ||
| elif repo: | ||
| aim_callback = CustomAimCallback(repo=repo, experiment=exp) | ||
| else: | ||
| aim_callback = CustomAimCallback(experiment=exp) | ||
|
|
||
| aim_callback.hash_export_path = hash_export_path | ||
| self.hf_callback = aim_callback | ||
| return self.hf_callback | ||
|
|
||
| def track(self, metric, name, stage="additional_metrics"): | ||
| context = {"subset": stage} | ||
| callback = self.hf_callback | ||
| run = callback.experiment | ||
| if run is not None: | ||
| run.track(metric, name=name, context=context) | ||
|
|
||
| def set_params(self, params, name="extra_params"): | ||
| callback = self.hf_callback | ||
| run = callback.experiment | ||
| if run is not None: | ||
| for key, value in params.items(): | ||
| run.set((name, key), value, strict=False) |
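The `Tracker` base class and `get_tracker` factory are imported earlier in this diff, but their files are not shown. A minimal sketch consistent with how they are used here; the method names come from the calls above, and everything else is an assumption rather than the PR's actual implementation:

```python
# Hypothetical sketch of tuning/trackers/tracker.py and
# tuning/trackers/tracker_factory.py, inferred from usage in this PR.
from tuning.trackers.aimstack_tracker import AimStackTracker  # module name assumed


class Tracker:
    """No-op base tracker; concrete trackers override these hooks."""

    def __init__(self, name=None, tracker_config=None):
        self.name = name
        self.config = tracker_config
        self.hf_callback = None

    def get_hf_callback(self):
        # Subclasses return a transformers.TrainerCallback, or None.
        return None

    def track(self, metric, name, stage):
        # Subclasses record `metric` under `name` within a stage/context.
        pass

    def set_params(self, params, name):
        # Subclasses attach key/value metadata to the run.
        pass


REGISTERED_TRACKERS = {"aim": AimStackTracker}


def get_tracker(name, config):
    # Fall back to the no-op Tracker so callers can treat the result
    # uniformly, matching how main() uses the return value.
    if name in REGISTERED_TRACKERS:
        return REGISTERED_TRACKERS[name](config)
    return Tracker()
```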