39 changes: 0 additions & 39 deletions tuning/aim_loader.py

This file was deleted.

8 changes: 8 additions & 0 deletions tuning/config/configs.py
@@ -72,3 +72,11 @@ class TrainingArguments(transformers.TrainingArguments):
default=False,
metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
)
tracker: str.lower = field(
default=None,
metadata={
"help": "Experiment tracker to use.\n" + \
"Available trackers are - aim, none\n" + \
"Requires additional configs, see tuning.configs/tracker_configs.py"
},
)
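A quick aside on the `str.lower` annotation on the new `tracker` field: the intent is presumably to normalize the CLI value to lowercase before it is compared against the supported tracker names. A minimal sketch with plain argparse (not the HfArgumentParser used in this PR) shows the mechanism, since argparse simply calls the callable given as `type` on the raw string:

# Illustration only: argparse applies the callable passed as `type` to the raw
# CLI string, so a value like "AIM" is normalized to "aim".
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tracker", type=str.lower, default=None)

args = parser.parse_args(["--tracker", "AIM"])
assert args.tracker == "aim"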
31 changes: 31 additions & 0 deletions tuning/config/tracker_configs.py
@@ -0,0 +1,31 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass

@dataclass
class AimConfig:
# Name of the experiment
experiment: str = None
# 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim') or a remote repository hosted on a server.
# When 'aim_remote_server_ip' or 'aim_remote_server_port' is set, it designates a remote aim repo.
# Otherwise, 'aim_repo' specifies the local repo directory, defaulting to '.aim'.
# See https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html for documentation on Aim remote server tracking.
aim_repo: str = ".aim"
aim_remote_server_ip: str = None
aim_remote_server_port: int = None
# Path where the run hash is exported. If unspecified, it is written to
# training_args.output_dir/.aim_run_hash when output_dir is set; otherwise it is not exported.
aim_run_hash_export_path: str = None
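As a reading aid for the config above, here is a minimal sketch of how AimConfig might be filled in for the two modes the comments describe, a local repo versus a remote Aim server. The experiment name and server address are placeholders, not values from this PR:

from tuning.config.tracker_configs import AimConfig

# Local tracking: runs go to the default ".aim" repo directory.
local_cfg = AimConfig(experiment="sft-smoke-test")

# Remote tracking: setting ip/port makes the callback target a remote Aim server.
remote_cfg = AimConfig(
    experiment="sft-smoke-test",
    aim_remote_server_ip="10.0.0.5",
    aim_remote_server_port=53800,
)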
94 changes: 81 additions & 13 deletions tuning/sft_trainer.py
@@ -14,9 +14,10 @@

# Standard
from datetime import datetime
from typing import Optional, Union
from typing import Dict, List, Optional, Union
import json
import os
import time
import sys

# Third Party
@@ -37,12 +38,15 @@
import transformers

# Local
from tuning.aim_loader import get_aimstack_callback
from tuning.config import configs, peft_config
from tuning.config import configs, peft_config, tracker_configs
from tuning.data import tokenizer_data_utils
from tuning.trackers.tracker import Tracker
from tuning.trackers.tracker_factory import get_tracker
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype

logger = logging.get_logger("sft_trainer")


class FileLoggingCallback(TrainerCallback):
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""
@@ -95,6 +99,9 @@ def train(
peft_config: Optional[ # pylint: disable=redefined-outer-name
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
] = None,
callbacks: Optional[List[TrainerCallback]] = None,
tracker: Optional[Tracker] = None,
Collaborator:
Since the train function originally accepts only configurations by design, I feel we need a strong reason to allow it to also include objects. The more natural way would be to accept one "tracker config" input.

Collaborator:
Using a YAML file as config was the first thought that came up too. But given that HF has chosen to use arguments for most configurations, we went with that as a design choice. If we choose to use config files for configs in fms-hf-tuning, we can do that too.

Collaborator (author):
The train function expects to take callbacks, which need to be associated with the training, hence main takes the config and initializes the tracker (which is needed for the callback).
The tracker here is the tracker object that the train function needs in order to track any extra metrics and metadata, hence the design choice to pass the tracker separately.

exp_metadata: Optional[Dict] = None,
):
"""Call the SFTTrainer

@@ -106,10 +113,13 @@
peft_config.PromptTuningConfig for prompt tuning | \
None for fine tuning
The peft configuration to pass to trainer
callbacks: List of callbacks to attach to the SFTTrainer.
tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS,
initialized using tuning.trackers.tracker_factory.get_tracker
with configs from tuning.config.tracker_configs.
exp_metadata: Dict of key-value pairs passed to train to be recorded by the tracker.
"""

logger = logging.get_logger("sft_trainer")

# Validate parameters
if (not isinstance(train_args.num_train_epochs, float)) or (
train_args.num_train_epochs <= 0
@@ -121,12 +131,16 @@
raise ValueError("gradient_accumulation_steps has to be an integer >= 1")

task_type = "CAUSAL_LM"
additional_metrics = {}

model_load_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=train_args.cache_dir,
torch_dtype=get_torch_dtype(model_args.torch_dtype),
use_flash_attention_2=model_args.use_flash_attn,
)
additional_metrics["model_load_time"] = time.time() - model_load_time

peft_config = get_hf_peft_config(task_type, peft_config)

@@ -218,10 +232,6 @@ def train(
"Validation dataset length is %s", len(formatted_validation_dataset)
)

aim_callback = get_aimstack_callback()
file_logger_callback = FileLoggingCallback(logger)
callbacks = [aim_callback, file_logger_callback]

if train_args.packing:
logger.info("Packing is set to True")
data_collator = None
@@ -261,6 +271,16 @@
peft_config=peft_config,
)

# We track additional metrics and experiment metadata after
# Trainer object creation to ensure that this is not repeated
# multiple times for FSDP runs.
if tracker is not None:
# Currently tracked only on process zero.
if trainer.is_world_process_zero():
for k, v in additional_metrics.items():
tracker.track(metric=v, name=k, stage="additional_metrics")
tracker.set_params(params=exp_metadata, name="experiment_metadata")

if trainer.is_fsdp_enabled and peft_config is not None:
trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
model
@@ -276,6 +296,7 @@ def main(**kwargs):  # pylint: disable=unused-argument
configs.TrainingArguments,
peft_config.LoraConfig,
peft_config.PromptTuningConfig,
tracker_configs.AimConfig,
)
)
parser.add_argument(
@@ -284,22 +305,69 @@
choices=["pt", "lora", None, "none"],
default="pt",
)
parser.add_argument(
"--exp_metadata",
type=str,
default=None,
help='Pass a json string representing K:V pairs to be associated with the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
)
(
model_args,
data_args,
training_args,
lora_config,
prompt_tuning_config,
peft_method,
aim_config,
additional,
_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if peft_method.peft_method == "lora":

peft_method = additional.peft_method
if peft_method == "lora":
tune_config = lora_config
elif peft_method.peft_method == "pt":
elif peft_method == "pt":
tune_config = prompt_tuning_config
else:
tune_config = None
train(model_args, data_args, training_args, tune_config)

tracker_name = training_args.tracker
if tracker_name == "aim":
tracker_config = aim_config
else:
tracker_config = None

# Initialize callbacks
file_logger_callback = FileLoggingCallback(logger)
callbacks = [file_logger_callback]

# Initialize the tracker
tracker = get_tracker(tracker_name, tracker_config)
Collaborator:
why not just have a tracker_install_callbacks function somewhere to reduce 3 lines into 1

Collaborator:
Agreed. Here too it's a design choice we have to make. We have tried to keep this consistent with how config is managed in fms-hf-tuning, following the pattern used in the peft config block above. But if we choose a different approach, we can follow that.

Collaborator (author):
It's a simple choice between doing things explicitly or running the code in some other function that reverts to no callback under the hood. It does not affect the functionality. If you strongly feel a meta function would help, we can make that change.

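A hypothetical helper along the lines the reviewer suggests above, collapsing the get-tracker / get-callback / append sequence into one call; the name `tracker_install_callbacks` and its signature are made up for illustration and are not part of this PR:

from tuning.trackers.tracker_factory import get_tracker

# Hypothetical convenience wrapper (not in this PR): build the tracker from its
# config and append its HF callback, if any, to the given callback list.
def tracker_install_callbacks(tracker_name, tracker_config, callbacks):
    tracker = get_tracker(tracker_name, tracker_config)
    tracker_callback = tracker.get_hf_callback()
    if tracker_callback is not None:
        callbacks.append(tracker_callback)
    return tracker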
tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
callbacks.append(tracker_callback)

# extra metadata passed via client
metadata = None
if additional.exp_metadata is not None:
try:
metadata = json.loads(additional.exp_metadata)
if metadata is None or not isinstance(metadata, dict):
logger.warning(
"metadata cannot be converted to a simple k:v dict; ignoring it"
)
metadata = None
except ValueError:
logger.error("failed while parsing extra metadata; please pass a valid json string")

train(
model_args=model_args,
data_args=data_args,
train_args=training_args,
peft_config=tune_config,
callbacks=callbacks,
tracker=tracker,
exp_metadata=metadata,
)


if __name__ == "__main__":
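To summarize the new sft_trainer.py flow, here is a rough sketch of driving train() programmatically with the new callbacks/tracker/exp_metadata arguments, mirroring the wiring main() does above. The config values are placeholders, and field names not visible in this diff (e.g., the DataArguments path field) are assumptions:

# Sketch only: mirrors main()'s wiring of tracker + callbacks into train().
from tuning import sft_trainer
from tuning.config import configs, tracker_configs
from tuning.trackers.tracker_factory import get_tracker

model_args = configs.ModelArguments(model_name_or_path="facebook/opt-125m")
data_args = configs.DataArguments(training_data_path="train.jsonl")  # field name assumed
train_args = configs.TrainingArguments(output_dir="out", tracker="aim", num_train_epochs=1.0)

tracker = get_tracker("aim", tracker_configs.AimConfig(experiment="sft-smoke-test"))
callbacks = [sft_trainer.FileLoggingCallback(sft_trainer.logger)]
tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
    callbacks.append(tracker_callback)

sft_trainer.train(
    model_args=model_args,
    data_args=data_args,
    train_args=train_args,
    peft_config=None,
    callbacks=callbacks,
    tracker=tracker,
    exp_metadata={"gpu": "A100-80G"},
)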
13 changes: 13 additions & 0 deletions tuning/trackers/__init__.py
@@ -0,0 +1,13 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
95 changes: 95 additions & 0 deletions tuning/trackers/aimstack_tracker.py
@@ -0,0 +1,95 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import os

# Third Party
from aim.hugging_face import AimCallback

# Local
from .tracker import Tracker
from tuning.config.tracker_configs import AimConfig

class CustomAimCallback(AimCallback):

# A path to export run hash generated by Aim
# This is used to link back to the experiments from outside aimstack
hash_export_path = None

def on_init_end(self, args, state, control, **kwargs):

if state and not state.is_world_process_zero:
return

self.setup() # initializes the run_hash

# Store the run hash
# Change default run hash path to output directory
if self.hash_export_path is None:
if args and args.output_dir:
# args.output_dir/.aim_run_hash
self.hash_export_path = os.path.join(
args.output_dir, ".aim_run_hash"
)

if self.hash_export_path:
with open(self.hash_export_path, "w") as f:
run_hash = self.experiment.hash
f.write('{"run_hash":"' + str(run_hash) + '"}\n')

def on_train_begin(self, args, state, control, model=None, **kwargs):
# Call setup directly to make sure hyperparameters and model info are recorded.
self.setup(args=args, state=state, model=model)

class AimStackTracker(Tracker):
def __init__(self, tracker_config: AimConfig):
super().__init__(name="aim", tracker_config=tracker_config)

def get_hf_callback(self):
c = self.config
exp = c.experiment
ip = c.aim_remote_server_ip
port = c.aim_remote_server_port
repo = c.aim_repo
hash_export_path = c.aim_run_hash_export_path

if ip is not None and port is not None:
aim_callback = CustomAimCallback(
repo="aim://" + ip + ":" + str(port) + "/", experiment=exp
)
elif repo:
aim_callback = CustomAimCallback(repo=repo, experiment=exp)
else:
aim_callback = CustomAimCallback(experiment=exp)

aim_callback.hash_export_path = hash_export_path
self.hf_callback = aim_callback
return self.hf_callback

def track(self, metric, name, stage="additional_metrics"):
context = {"subset": stage}
callback = self.hf_callback
run = callback.experiment
if run is not None:
run.track(metric, name=name, context=context)

def set_params(self, params, name="extra_params"):
if params is None:
return
run = self.hf_callback.experiment
if run is not None:
for key, value in params.items():
run.set((name, key), value, strict=False)
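The Tracker base class imported at the top of this file (tuning/trackers/tracker.py) is not part of this diff; the following is only a guess at its shape, inferred from how AimStackTracker and sft_trainer.py use it, with any details beyond those call sites assumed:

# Inferred sketch of the base interface; the real tracker.py may differ.
class Tracker:
    def __init__(self, name=None, tracker_config=None):
        self.name = name
        self.config = tracker_config
        self.hf_callback = None

    def get_hf_callback(self):
        # The base/default tracker contributes no HF TrainerCallback.
        return None

    def track(self, metric, name, stage=None):
        # No-op for the base/default tracker.
        pass

    def set_params(self, params, name=None):
        # No-op for the base/default tracker.
        pass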