From 8be8273e0a101b6cef53e4780e9b62d9f3b495ba Mon Sep 17 00:00:00 2001 From: Anh-Uong Date: Wed, 19 Jun 2024 15:21:10 -0600 Subject: [PATCH 1/7] deps: pin transformers at v4.40 - update unit tests with old evaluation_strategy flag Signed-off-by: Anh-Uong --- pyproject.toml | 2 +- tests/test_sft_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 052de5ee0..92f61ae16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers=[ dependencies = [ "numpy", "accelerate>=0.20.3", -"transformers", +"transformers>=4.34.1,<4.41", "torch", "sentencepiece", "tokenizers>=0.13.3", diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 19db3402e..ffe32f436 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -301,7 +301,7 @@ def test_run_causallm_pt_with_validation(): with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir - train_args.eval_strategy = "epoch" + train_args.evaluation_strategy = "epoch" data_args = copy.deepcopy(DATA_ARGS) data_args.validation_data_path = TWITTER_COMPLAINTS_DATA @@ -314,7 +314,7 @@ def test_run_causallm_pt_with_validation_data_formatting(): with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir - train_args.eval_strategy = "epoch" + train_args.evaluation_strategy = "epoch" data_args = copy.deepcopy(DATA_ARGS) data_args.validation_data_path = TWITTER_COMPLAINTS_DATA data_args.dataset_text_field = None From 49305ad2d13ea6030cd4ede68e5ffb59a1e5f4a2 Mon Sep 17 00:00:00 2001 From: Will <49654846+willmj@users.noreply.github.com> Date: Wed, 14 Aug 2024 18:18:36 -0400 Subject: [PATCH 2/7] release: merge set of changes for v1.2.0 (#299) * Set default value of target_modules to be None in LoraConfig Signed-off-by: Will Johnson * Removal of transformers logger and addition of python logger Signed-off-by: Abhishek * FMT and lint check: Removal of transformers logger and addition of python logger Signed-off-by: Abhishek * fix: remove lm_head for granite with llama arch models (#258) * initial code for deleting lm_head Signed-off-by: Anh-Uong * fix logic for copying checkpoint Signed-off-by: Anh-Uong * fix check that embed_tokens and lm_head weights are the same Signed-off-by: Anh-Uong * fix warning assertion Signed-off-by: Anh-Uong * fix lm_head check, remove test Signed-off-by: Anh-Uong * small fixes from code review Signed-off-by: Anh-Uong * fmt Signed-off-by: Anh-Uong --------- Signed-off-by: Anh-Uong Co-authored-by: Anh-Uong Signed-off-by: Abhishek * Add config_utils tests Signed-off-by: Angel Luu * Fix fmt Signed-off-by: Angel Luu * Separate tests out and use docstrings Signed-off-by: Angel Luu * Update more field/value checks from HF defaults Signed-off-by: Angel Luu * Fix: Addition of env var TRANSFORMERS_VERBOSITY check Signed-off-by: Abhishek * FMT Fix: Addition of env var TRANSFORMERS_VERBOSITY check Signed-off-by: Abhishek * Add test for tokenizer in lora config (should be ignored) Signed-off-by: Angel Luu * Adding logging support to accelerate launch Signed-off-by: Abhishek * FMT_FIX: Adding logging support to accelerate launch Signed-off-by: Abhishek * bug: On save event added to callback (#256) * feat: On save event added to callback Signed-off-by: Padmanabha V Seshadri * fix: Removed additional bracket Signed-off-by: Padmanabha V Seshadri * fix: Removed additional bracket Signed-off-by: Padmanabha V Seshadri * fix: Format issues resolved Signed-off-by: Padmanabha V Seshadri * fix: rebase with upstream and add new line Signed-off-by: Mehant Kammakomati --------- Signed-off-by: Padmanabha V Seshadri Signed-off-by: Mehant Kammakomati Co-authored-by: Mehant Kammakomati * feat: All metric handling changes (#263) * feat: All metric handling changes Signed-off-by: Padmanabha V Seshadri * fix: Format issues Signed-off-by: Padmanabha V Seshadri --------- Signed-off-by: Padmanabha V Seshadri * feat: Configuration to set logging level for trigger log (#241) * feat: Added the triggered login in the operation Signed-off-by: Padmanabha V Seshadri * fix: Formatting issues Signed-off-by: Padmanabha V Seshadri * fix: Added default config Signed-off-by: Padmanabha V Seshadri * fix: Moved the variable to right scope Signed-off-by: Padmanabha V Seshadri * fix: Checked added to validate config log level Signed-off-by: Padmanabha V Seshadri * fix: Removed some unwanted log file Signed-off-by: Padmanabha V Seshadri --------- Signed-off-by: Padmanabha V Seshadri * limit peft deps until investigate (#274) Signed-off-by: Anh-Uong * Data custom collator (#260) * refactor code to preprocess datasets Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * fix formatting Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * allow input/output in validate args Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * format input/output JSON and mask Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * function to return suitable collator Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * add tests for SFT Trainer input/output format Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * remove unused functions Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * add eos token to input/output format Signed-off-by: Sukriti-Sharma4 * fix tests Signed-off-by: Sukriti-Sharma4 * improve docstrings Signed-off-by: Sukriti-Sharma4 * keeping JSON keys constant Signed-off-by: Sukriti-Sharma4 * support for input/output format Signed-off-by: Sukriti-Sharma4 * formatting fixes Signed-off-by: Sukriti-Sharma4 * update rEADME formats Signed-off-by: Sukriti-Sharma4 * formatting README Signed-off-by: Sukriti-Sharma4 --------- Signed-off-by: Sukriti-Sharma4 Co-authored-by: Alex-Brooks * Revert "limit peft deps until investigate (#274)" (#275) This reverts commit f57ff63650ba139d6e0471d244df4a70e4b13d0b. Signed-off-by: Anh-Uong * feat: per process state metric (#239) Signed-off-by: Harikrishnan Balagopal * Modify test to pass with target_modules: None Signed-off-by: Will Johnson * Logging changes and unit tests added Signed-off-by: Abhishek * feat: Add a dockerfile argument to enable aimstack (#261) * Add a dockerfile argument at the end of final layer to enable aimstack. Currenlty guarded by a dockerfile argument. Signed-off-by: Dushyant Behl * Set the default value of ENABLE_AIM to false Signed-off-by: Dushyant Behl --------- Signed-off-by: Dushyant Behl * Solved conflict with main Signed-off-by: Abhishek * FMT:Fix Solved conflict with main Signed-off-by: Abhishek * enabling tests for prompt tuning Signed-off-by: Abhishek * feat: Support pretokenized (#272) * feat: support pretokenized datasets Signed-off-by: Mehant Kammakomati * fix: rebase with upstream and review commits Signed-off-by: Mehant Kammakomati * fix: rebase with upstream and review commits Signed-off-by: Mehant Kammakomati * fix: rebase with upstream and review commits Signed-off-by: Mehant Kammakomati * consolidate collator code Signed-off-by: Sukriti-Sharma4 * add valuerrors for incorrect args Signed-off-by: Sukriti-Sharma4 * feat: add unit tests for validate_data_args and format_dataset Signed-off-by: Mehant Kammakomati * feat: add unit tests for validate_data_args and format_dataset Signed-off-by: Mehant Kammakomati * feat: add unit tests for validate_data_args and format_dataset Signed-off-by: Mehant Kammakomati * feat: add unit tests for validate_data_args and format_dataset Signed-off-by: Mehant Kammakomati --------- Signed-off-by: Mehant Kammakomati Signed-off-by: Sukriti-Sharma4 Co-authored-by: Sukriti-Sharma4 Co-authored-by: Alex Brooks * Update packaging requirement from <24,>=23.2 to >=23.2,<25 (#212) Updates the requirements on [packaging](https://github.com/pypa/packaging) to permit the latest version. - [Release notes](https://github.com/pypa/packaging/releases) - [Changelog](https://github.com/pypa/packaging/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pypa/packaging/compare/23.2...24.1) --- updated-dependencies: - dependency-name: packaging dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Anh Uong * enabling tests for prompt tuning (#278) Signed-off-by: Abhishek Co-authored-by: Anh Uong * fix: do not add special tokens for custom tokenizer (#279) Signed-off-by: Mehant Kammakomati * PR changes for changing logger Signed-off-by: Abhishek * fix: bug where the logger was not being used properly (#286) Signed-off-by: Hari * Unit Tests changes Signed-off-by: Abhishek * Add functionality to free disk space from Github Actions (#287) * Add functionality to free disk space from Github Actions Signed-off-by: Will Johnson * Add functionality to free disk space from Github Actions, relocate from build-and-publish.yaml to image.yaml Signed-off-by: Will Johnson * Move freeing space step to before building image Signed-off-by: Will Johnson --------- Signed-off-by: Will Johnson * commented os.environ[LOG_LEVEL] in accelerate.py for testing Signed-off-by: Abhishek * PR changes Signed-off-by: Abhishek * FIX:FMT Signed-off-by: Abhishek * PR Changes Signed-off-by: Abhishek * PR Changes Signed-off-by: Abhishek * Add unit test to verify target_modules defaults correctly (#281) * Add unit test to verify target_modules defaults correctly Signed-off-by: Will Johnson * Add sft_trainer.main test to ensure target modules properly default for LoRA when set to None from CLI Signed-off-by: Will Johnson * fmt Signed-off-by: Will Johnson * Use model_args instead of importing, fix nits Signed-off-by: Will Johnson * Add test to ensure target_modules defaults to None in job config Signed-off-by: Will Johnson * Add additional check, fix nits Signed-off-by: Will Johnson --------- Signed-off-by: Will Johnson * docs: Add documentation on experiment tracking. (#257) Signed-off-by: Dushyant Behl * Ensure additional metadata to trackers don't throw error in happy case. (#290) Signed-off-by: Dushyant Behl * PR Changes Signed-off-by: Abhishek * fix multiple runid creation bug with accelerate. (#268) Signed-off-by: Dushyant Behl * feat: logging control operation (#264) Signed-off-by: Padmanabha V Seshadri * Metrics file epoch indexing from 0 Signed-off-by: Abhishek * Revert last commit Signed-off-by: Abhishek * fix run evaluation to get base model path (#273) Signed-off-by: Anh-Uong * PR Changes Signed-off-by: Abhishek * PR Changes Signed-off-by: Abhishek * feat: Added additional events such as on_step_begin, on_optimizer_step, on_substep_end (#293) Signed-off-by: Padmanabha V Seshadri * Always update setuptools to latest (#288) Signed-off-by: James Busche Co-authored-by: Anh Uong * Rename all fixtures with correct .jsonl extension (#295) Signed-off-by: Will Johnson Co-authored-by: Anh Uong * feat: add save_model_dir flag where final checkpoint saved (#291) * add save_model_dir flag for final checkpoint Signed-off-by: Anh-Uong * remove output_dir logic, add save method Signed-off-by: Anh-Uong * update accelerate_launch, remove save tokenizer Signed-off-by: Anh-Uong * fix: put back creation of .complete file Signed-off-by: Anh-Uong * fix failing tests and add new ones Signed-off-by: Anh-Uong * tests: add sft_trainer test to train and save - small refactor of tests Signed-off-by: Anh-Uong * add docs on saving checkpoints and fix help msg Signed-off-by: Anh-Uong * update example and note best checkpoint Signed-off-by: Anh-Uong * changes based on PR review Signed-off-by: Anh-Uong * add logging to save, fix error out properly Signed-off-by: Anh-Uong --------- Signed-off-by: Anh-Uong --------- Signed-off-by: Will Johnson Signed-off-by: Abhishek Signed-off-by: Anh-Uong Signed-off-by: Angel Luu Signed-off-by: Padmanabha V Seshadri Signed-off-by: Mehant Kammakomati Signed-off-by: Sukriti-Sharma4 Signed-off-by: Harikrishnan Balagopal Signed-off-by: Dushyant Behl Signed-off-by: dependabot[bot] Signed-off-by: Hari Signed-off-by: James Busche Co-authored-by: Abhishek Co-authored-by: Sukriti Sharma Co-authored-by: Anh-Uong Co-authored-by: Abhishek Maurya <124327945+Abhishek-TAMU@users.noreply.github.com> Co-authored-by: Angel Luu Co-authored-by: Angel Luu Co-authored-by: Padmanabha V Seshadri Co-authored-by: Mehant Kammakomati Co-authored-by: Alex-Brooks Co-authored-by: Hari Co-authored-by: Dushyant Behl Co-authored-by: Sukriti-Sharma4 Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: James Busche --- .github/workflows/image.yaml | 10 +- README.md | 100 +++++- build/Dockerfile | 21 +- build/accelerate_launch.py | 242 +++++++-------- docs/experiment-tracking.md | 133 ++++++++ ...ining-loss-below-threshold-log-config.yaml | 14 + .../log_controller.yaml | 16 + examples/trainercontroller_configs/loss.yaml | 4 +- pyproject.toml | 2 +- scripts/run_evaluation.py | 19 +- tests/build/test_launch_script.py | 132 +++++++- tests/data/__init__.py | 7 +- tests/data/trainercontroller/__init__.py | 2 + .../trainercontroller/log_controller.yaml | 16 + .../loss_custom_operation.yaml | 4 +- .../loss_custom_operation_invalid_action.yaml | 4 +- .../loss_invalid_metric.yaml | 4 +- .../loss_invalid_operation.yaml | 4 +- .../loss_invalid_operation_action.yaml | 4 +- .../loss_invalid_trigger.yaml | 4 +- .../trainercontroller/loss_on_threshold.yaml | 4 +- .../loss_on_threshold_with_trainer_state.yaml | 6 +- .../loss_unavailable_metric.yaml | 4 +- .../loss_with_invalid_type_rule.yaml | 2 +- .../loss_with_malicious_input_rule.yaml | 2 +- .../loss_with_malicious_os_rule.yaml | 2 +- tests/data/trainercontroller/on-save.yaml | 10 + ... => twitter_complaints_input_output.jsonl} | 0 ...ll.json => twitter_complaints_small.jsonl} | 0 ..._tokenized_with_maykeye_tinyllama_v0.jsonl | 10 + tests/test_sft_trainer.py | 214 +++++++++++-- tests/trackers/test_aim_tracker.py | 3 +- tests/trackers/test_file_logging_tracker.py | 3 +- tests/trainercontroller/custom_operation.py | 7 - .../custom_operation_invalid_action.py | 7 - .../test_tuning_trainercontroller.py | 33 ++ tests/utils/test_config_utils.py | 234 ++++++++++++++ tests/utils/test_logging.py | 84 +++++ tests/utils/test_preprocessing_utils.py | 290 +++++++++++++----- tuning/config/configs.py | 29 +- tuning/config/peft_config.py | 2 +- tuning/sft_trainer.py | 149 ++++++--- tuning/trackers/aimstack_tracker.py | 15 +- tuning/trackers/filelogging_tracker.py | 5 +- tuning/trackers/tracker_factory.py | 15 +- tuning/trainercontroller/callback.py | 88 +++++- .../controllermetrics/__init__.py | 2 + .../controllermetrics/eval_metrics.py | 5 - .../history_based_metrics.py | 2 - .../controllermetrics/loss.py | 2 +- .../controllermetrics/metrics.yaml | 9 + .../controllermetrics/per_process_state.py | 77 +++++ .../controllermetrics/trainingstate.py | 5 +- .../trainercontroller/operations/__init__.py | 2 + .../operations/hfcontrols.py | 8 +- .../operations/logcontrol.py | 55 ++++ .../trainercontroller/operations/operation.py | 40 ++- tuning/trainercontroller/patience.py | 10 +- tuning/utils/data_type_utils.py | 6 +- tuning/utils/logging.py | 64 ++++ tuning/utils/preprocessing_utils.py | 279 ++++++++--------- 61 files changed, 1963 insertions(+), 563 deletions(-) create mode 100644 docs/experiment-tracking.md create mode 100644 examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml create mode 100644 examples/trainercontroller_configs/log_controller.yaml create mode 100644 tests/data/trainercontroller/log_controller.yaml create mode 100644 tests/data/trainercontroller/on-save.yaml rename tests/data/{twitter_complaints_input_output.json => twitter_complaints_input_output.jsonl} (100%) rename tests/data/{twitter_complaints_small.json => twitter_complaints_small.jsonl} (100%) create mode 100644 tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl create mode 100644 tests/utils/test_config_utils.py create mode 100644 tests/utils/test_logging.py create mode 100644 tuning/trainercontroller/controllermetrics/metrics.yaml create mode 100644 tuning/trainercontroller/controllermetrics/per_process_state.py create mode 100644 tuning/trainercontroller/operations/logcontrol.py create mode 100644 tuning/utils/logging.py diff --git a/.github/workflows/image.yaml b/.github/workflows/image.yaml index 73bbe5982..d4d836bec 100644 --- a/.github/workflows/image.yaml +++ b/.github/workflows/image.yaml @@ -10,6 +10,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - name: Free disk space + run: | + sudo swapoff -a + sudo rm -f /swapfile + sudo apt clean + docker rmi $(docker image ls -aq) + df -h - name: Build image run: | - docker build -t fms-hf-tuning:dev . -f build/Dockerfile \ No newline at end of file + docker build -t fms-hf-tuning:dev . -f build/Dockerfile + \ No newline at end of file diff --git a/README.md b/README.md index ac65f1c24..4bc762e4d 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ - [Training](#training) - [Single GPU](#single-gpu) - [Multiple GPUs with FSDP](#multiple-gpus-with-fsdp) + - [Tips on Parameters to Set](#tips-on-parameters-to-set) - [Tuning Techniques](#tuning-techniques) - [LoRA Tuning Example](#lora-tuning-example) - [Prompt Tuning](#prompt-tuning) @@ -18,6 +19,7 @@ - [Changing the Base Model for Inference](#changing-the-base-model-for-inference) - [Validation](#validation) - [Trainer Controller Framework](#trainer-controller-framework) +- [Experiment Tracking](#experiment-tracking) - [More Examples](#more-examples) This repo provides basic tuning scripts with support for specific models. The repo relies on Hugging Face `SFTTrainer` and PyTorch FSDP. Our approach to tuning is: @@ -27,10 +29,14 @@ This repo provides basic tuning scripts with support for specific models. The re ## Installation +### Basic Installation + ``` pip install fms-hf-tuning ``` +### Using FlashAttention + > Note: After installing, if you wish to use [FlashAttention](https://github.com/Dao-AILab/flash-attention), then you need to install these requirements: ``` pip install fms-hf-tuning[dev] @@ -38,21 +44,28 @@ pip install fms-hf-tuning[flash-attn] ``` [FlashAttention](https://github.com/Dao-AILab/flash-attention) requires the [CUDA Toolit](https://developer.nvidia.com/cuda-toolkit) to be pre-installed. -If you wish to use [aim](https://github.com/aimhubio/aim), then you need to install it: -``` -pip install fms-hf-tuning[aim] -``` +### Using FMS-Acceleration If you wish to use [fms-acceleration](https://github.com/foundation-model-stack/fms-acceleration), you need to install it. ``` pip install fms-hf-tuning[fms-accel] ``` -`fms-acceleration` is a collection of plugins that packages that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details on see [this section below](#fms-acceleration). +`fms-acceleration` is a collection of plugins that packages that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details see [this section below](#fms-acceleration). + +### Using Experiment Trackers + +To use experiment tracking with popular tools like [Aim](https://github.com/aimhubio/aim), note that some trackers are considered optional dependencies and can be installed with the following command: +``` +pip install fms-hf-tuning[aim] +``` +For more details on how to enable and use the trackers, Please see, [the experiment tracking section below](#experiment-tracking). ## Data format -We support two data formats: +We support the following data formats: + +### 1. JSON formats with a single sequence and a specified response_template to use for masking on completion. -1. #### Pre-process the JSON/JSONL dataset +#### 1.1 Pre-process the JSON/JSONL dataset Pre-process the JSON/JSONL dataset to contain a single sequence of each data instance containing input + Response. The trainer is configured to expect a response template as a string. For example, if one wants to prepare the `alpaca` format data to feed into this trainer, it is quite easy and can be done with the following code. ```python @@ -87,7 +100,7 @@ The same way can be applied to any dataset, with more info can be found [here](h Once the JSON is converted using the formatting function, pass the `dataset_text_field` containing the single sequence to the trainer. -2. #### Format JSON/JSONL on the fly +#### 1.2 Format JSON/JSONL on the fly Pass a JSON/JSONL and a `data_formatter_template` to use the formatting function on the fly while tuning. The template should specify fields of JSON with `{{field}}`. While tuning, the data will be converted to a single sequence using the template. JSON fields can contain alpha-numeric characters, spaces and the following special symbols - "." , "_", "-". @@ -101,8 +114,20 @@ data_formatter_template: `### Input: {{input}} \n\n##Label: {{output}}` Formatting will happen on the fly while tuning. The keys in template should match fields in JSON file. The `response template` corresponding to the above template will need to be supplied. in this case, `response template` = `\n## Label:`. +##### In conclusion, if using the reponse_template and single sequence, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer. -##### In conclusion, either the `data_formatter_template` argument or `dataset_text_field` needs to be supplied to the trainer. +### 2. JSONL with input and output fields (no response template) + + Pass a JSONL containing fields "input" with source text and "output" with class labels. Pre-format the input as you see fit. The output field will simply be concatenated to the end of input to create single sequence, and input will be masked. + + The "input" and "output" field names are mandatory and cannot be changed. + +Example: Train.jsonl + +``` +{"input": "### Input: Colorado is a state in USA ### Output:", "output": "USA : Location"} +{"input": "### Input: Arizona is also a state in USA ### Output:", "output": "USA : Location"} +``` ## Supported Models @@ -201,6 +226,50 @@ tuning/sft_trainer.py \ To summarize you can pick either python for single-GPU jobs or use accelerate launch for multi-GPU jobs. The following tuning techniques can be applied: +### Tips on Parameters to Set + +#### Saving checkpoints while training + +By default, [`save_strategy`](tuning/config/configs.py) is set to `"epoch"` in the TrainingArguments. This means that checkpoints will be saved on each epoch. This can also be set to `"steps"` to save on every `"save_steps"` or `"no"` to not save any checkpoints. + +Checkpoints are saved to the given `output_dir`, which is a required field. If `save_strategy="no"`, the `output_dir` will only contain the training logs with loss details. + +A useful flag to set to limit the number of checkpoints saved is [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit). Older checkpoints are deleted from the `output_dir` to limit the number of checkpoints, for example, if `save_total_limit=1`, this will only save the last checkpoint. However, while tuning, two checkpoints will exist in `output_dir` for a short time as the new checkpoint is created and then the older one will be deleted. If the user sets a validation dataset and [`load_best_model_at_end`](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments.load_best_model_at_end), then the best checkpoint will be saved. + +#### Saving model after training + +`save_model_dir` can optionally be set to save the tuned model using `SFTTrainer.save_model()`. This can be used in tandem with `save_strategy="no"` to only save the designated checkpoint and not any intermediate checkpoints, which can help to save space. + +`save_model_dir` can be set to a different directory than `output_dir`. If set to the same directory, the designated checkpoint, training logs, and any intermediate checkpoints will all be saved to the same directory as seen below. + +
+Ways you can use `save_model_dir` and more tips: + +For example, if `save_model_dir` is set to a sub-directory of `output_dir`and `save_total_limit=1` with LoRA tuning, the directory would look like: + +```sh +$ ls /tmp/output_dir/ +checkpoint-35 save_model_dir training_logs.jsonl + +$ ls /tmp/output_dir/save_model_dir/ +README.md adapter_model.safetensors special_tokens_map.json tokenizer.model training_args.bin +adapter_config.json added_tokens.json tokenizer.json tokenizer_config.json +``` + +Here is an fine tuning example of how the directory would look if `output_dir` is set to the same value as `save_model_dir` and `save_total_limit=2`. Note the checkpoint directories as well as the `training_logs.jsonl`: + +```sh +$ ls /tmp/same_dir + +added_tokens.json model-00001-of-00006.safetensors model-00006-of-00006.safetensors tokenizer_config.json +checkpoint-16 model-00002-of-00006.safetensors model.safetensors.index.json training_args.bin +checkpoint-20 model-00003-of-00006.safetensors special_tokens_map.json training_logs.jsonl +config.json model-00004-of-00006.safetensors tokenizer.json +generation_config.json model-00005-of-00006.safetensors tokenizer.model +``` + +
+ ## Tuning Techniques: ### LoRA Tuning Example @@ -549,6 +618,19 @@ This framework helps users define rules to capture scenarios like criteria for s For details about how you can use set a custom stopping criteria and perform custom operations, see [examples/trainercontroller_configs/Readme.md](examples/trainercontroller_configs/Readme.md) + +## Experiment Tracking + +Experiment tracking in fms-hf-tuning allows users to track their experiments with known trackers like [Aimstack](https://aimstack.io/) or custom trackers built into the code like +[FileLoggingTracker](./tuning/trackers/filelogging_tracker.py) + +The code supports currently two trackers out of the box, +* `FileLoggingTracker` : A built in tracker which supports logging training loss to a file. +* `Aimstack` : A popular opensource tracker which can be used to track any metrics or metadata from the experiments. + +Further details on enabling and using the trackers mentioned above can be found [here](docs/experiment-tracking.md). + + ## More Examples [Prompt Tuning on Twitter Complaints](examples/prompt_tuning_twitter_complaints/README.md) diff --git a/build/Dockerfile b/build/Dockerfile index 6ccd21e96..de00ddaee 100644 --- a/build/Dockerfile +++ b/build/Dockerfile @@ -33,6 +33,7 @@ RUN dnf remove -y --disableplugin=subscription-manager \ && ln -s /usr/bin/python${PYTHON_VERSION} /bin/python \ && python -m ensurepip --upgrade \ && python -m pip install --upgrade pip \ + && python -m pip install --upgrade setuptools \ && dnf update -y \ && dnf clean all @@ -104,6 +105,9 @@ ARG WHEEL_VERSION ARG USER ARG USER_UID +## Enable Aimstack if requested via ENABLE_AIM set to "true" +ARG ENABLE_AIM=false + RUN dnf install -y git && \ # perl-Net-SSLeay.x86_64 and server_key.pem are installed with git as dependencies # Twistlock detects it as H severity: Private keys stored in image @@ -129,6 +133,9 @@ RUN --mount=type=cache,target=/home/${USER}/.cache/pip,uid=${USER_UID} \ python -m pip install --user wheel && \ python -m pip install --user "$(head bdist_name)" && \ python -m pip install --user "$(head bdist_name)[flash-attn]" && \ + if [[ "${ENABLE_AIM}" == "true" ]]; then \ + python -m pip install --user "$(head bdist_name)[aim]"; \ + fi && \ # Clean up the wheel module. It's only needed by flash-attn install python -m pip uninstall wheel build -y && \ # Cleanup the bdist whl file @@ -146,6 +153,14 @@ RUN mkdir /app && \ chown -R $USER:0 /app /tmp && \ chmod -R g+rwX /app /tmp +# Need a better way to address these hacks +RUN if [[ "${ENABLE_AIM}" == "true" ]] ; then \ + touch /.aim_profile && \ + chmod -R 777 /.aim_profile; \ + fi +RUN mkdir /.cache && \ + chmod -R 777 /.cache + # Copy scripts and default configs COPY build/accelerate_launch.py fixtures/accelerate_fsdp_defaults.yaml /app/ COPY build/utils.py /app/build/ @@ -154,12 +169,6 @@ RUN chmod +x /app/accelerate_launch.py ENV FSDP_DEFAULTS_FILE_PATH="/app/accelerate_fsdp_defaults.yaml" ENV SET_NUM_PROCESSES_TO_NUM_GPUS="True" -# Need a better way to address this hack -RUN touch /.aim_profile && \ - chmod -R 777 /.aim_profile && \ - mkdir /.cache && \ - chmod -R 777 /.cache - WORKDIR /app USER ${USER} COPY --from=python-installations /home/${USER}/.local /home/${USER}/.local diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py index ee8718b5d..d7753728c 100644 --- a/build/accelerate_launch.py +++ b/build/accelerate_launch.py @@ -23,8 +23,6 @@ import subprocess import sys import traceback -import tempfile -import shutil from pathlib import Path import json @@ -37,12 +35,9 @@ # Local from build.utils import ( process_accelerate_launch_args, - serialize_args, get_highest_checkpoint, - copy_checkpoint, ) from tuning.utils.config_utils import get_json_config -from tuning.config.tracker_configs import FileLoggingTrackerConfig from tuning.utils.error_logging import ( write_termination_log, USER_ERROR_EXIT_CODE, @@ -61,9 +56,6 @@ def get_base_model_from_adapter_config(adapter_config): def main(): - LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper() - logging.basicConfig(level=LOGLEVEL) - if not os.getenv("TERMINATION_LOG_FILE"): os.environ["TERMINATION_LOG_FILE"] = ERROR_LOG @@ -80,6 +72,18 @@ def main(): or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'." ) + # Configure log_level of python native logger. + # CLI arg takes precedence over env var. And if neither is set, we use default "WARNING" + log_level = job_config.get( + "log_level" + ) # this will be set to either the value found or None + if ( + not log_level + ): # if log level not set by job_config aka by JSON, set it via env var or set default + log_level = os.environ.get("LOG_LEVEL", "WARNING") + log_level = log_level.upper() + logging.basicConfig(level=log_level) + args = process_accelerate_launch_args(job_config) logging.debug("accelerate launch parsed args: %s", args) except FileNotFoundError as e: @@ -102,143 +106,111 @@ def main(): # Launch training # ########## - original_output_dir = job_config.get("output_dir") - with tempfile.TemporaryDirectory() as tempdir: - try: - # checkpoints outputted to tempdir, only final checkpoint copied to output dir - job_config["output_dir"] = tempdir - updated_args = serialize_args(job_config) - os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = updated_args - - launch_command(args) - except subprocess.CalledProcessError as e: - # If the subprocess throws an exception, the base exception is hidden in the - # subprocess call and is difficult to access at this level. However, that is not - # an issue because sft_trainer.py would have already written the exception - # message to termination log. - logging.error(traceback.format_exc()) - # The exit code that sft_trainer.py threw is captured in e.returncode - - return_code = e.returncode - if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]: - return_code = INTERNAL_ERROR_EXIT_CODE - write_termination_log(f"Unhandled exception during training. {e}") - sys.exit(return_code) - except Exception as e: # pylint: disable=broad-except - logging.error(traceback.format_exc()) + output_dir = job_config.get("output_dir") + try: + # checkpoints outputted to tempdir, only final checkpoint copied to output dir + launch_command(args) + except subprocess.CalledProcessError as e: + # If the subprocess throws an exception, the base exception is hidden in the + # subprocess call and is difficult to access at this level. However, that is not + # an issue because sft_trainer.py would have already written the exception + # message to termination log. + logging.error(traceback.format_exc()) + # The exit code that sft_trainer.py threw is captured in e.returncode + + return_code = e.returncode + if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]: + return_code = INTERNAL_ERROR_EXIT_CODE write_termination_log(f"Unhandled exception during training. {e}") - sys.exit(INTERNAL_ERROR_EXIT_CODE) + sys.exit(return_code) + except Exception as e: # pylint: disable=broad-except + logging.error(traceback.format_exc()) + write_termination_log(f"Unhandled exception during training. {e}") + sys.exit(INTERNAL_ERROR_EXIT_CODE) - try: - last_checkpoint_dir = get_highest_checkpoint(tempdir) - last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir) + # remove lm_head from granite with llama arch models + try: + checkpoint_dir = job_config.get("save_model_dir") + if not checkpoint_dir: + checkpoint_dir = os.path.join( + output_dir, get_highest_checkpoint(output_dir) + ) - use_flash_attn = job_config.get("use_flash_attn", True) - adapter_config_path = os.path.join( - last_checkpoint_path, "adapter_config.json" + use_flash_attn = job_config.get("use_flash_attn", True) + adapter_config_path = os.path.join(checkpoint_dir, "adapter_config.json") + tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir) + + if os.path.exists(adapter_config_path): + base_model_path = get_base_model_from_adapter_config(adapter_config_path) + base_model = AutoModelForCausalLM.from_pretrained( + base_model_path, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, ) - tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path) - if os.path.exists(adapter_config_path): - base_model_path = get_base_model_from_adapter_config( - adapter_config_path - ) - base_model = AutoModelForCausalLM.from_pretrained( - base_model_path, - attn_implementation="flash_attention_2" if use_flash_attn else None, - torch_dtype=bfloat16 if use_flash_attn else None, - ) + # since the peft library (PEFTModelForCausalLM) does not handle cases + # where the model's layers are modified, in our case the embedding layer + # is modified, so we resize the backbone model's embedding layer with our own + # utility before passing it along to load the PEFT model. + tokenizer_data_utils.tokenizer_and_embedding_resize( + {}, tokenizer=tokenizer, model=base_model + ) + model = PeftModel.from_pretrained( + base_model, + checkpoint_dir, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + checkpoint_dir, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, + ) - # since the peft library (PEFTModelForCausalLM) does not handle cases - # where the model's layers are modified, in our case the embedding layer - # is modified, so we resize the backbone model's embedding layer with our own - # utility before passing it along to load the PEFT model. - tokenizer_data_utils.tokenizer_and_embedding_resize( - {}, tokenizer=tokenizer, model=base_model - ) - model = PeftModel.from_pretrained( - base_model, - last_checkpoint_path, - attn_implementation="flash_attention_2" if use_flash_attn else None, - torch_dtype=bfloat16 if use_flash_attn else None, + model_arch = model.config.model_type + # check that it is a granite model with llama architecture with tied weights + # ie. lm_head is duplicate of embeddings + + # a fine tuned model will have params_dict.get("model.embed_tokens.weight") + # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight") + # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight") + if model_arch == "llama" and hasattr(model, "lm_head"): + if ( + # lora tuned model has an addt model layer + ( + hasattr(model.model, "model") + and model.lm_head.weight.untyped_storage().data_ptr() + == model.model.model.embed_tokens.weight.untyped_storage().data_ptr() ) - else: - model = AutoModelForCausalLM.from_pretrained( - last_checkpoint_path, - attn_implementation="flash_attention_2" if use_flash_attn else None, - torch_dtype=bfloat16 if use_flash_attn else None, + # prompt tuned model or fine tuned model + or ( + hasattr(model.model, "embed_tokens") + and model.lm_head.weight.untyped_storage().data_ptr() + == model.model.embed_tokens.weight.untyped_storage().data_ptr() ) + ): - model_arch = model.config.model_type - # check that it is a granite model with llama architecture with tied weights - # ie. lm_head is duplicate of embeddings - - # a fine tuned model will have params_dict.get("model.embed_tokens.weight") - # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight") - # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight") - copy_checkpoint_bool = True - if model_arch == "llama" and hasattr(model, "lm_head"): - if ( - # lora tuned model has an addt model layer - ( - hasattr(model.model, "model") - and model.lm_head.weight.untyped_storage().data_ptr() - == model.model.model.embed_tokens.weight.untyped_storage().data_ptr() - ) - # prompt tuned model or fine tuned model - or ( - hasattr(model.model, "embed_tokens") - and model.lm_head.weight.untyped_storage().data_ptr() - == model.model.embed_tokens.weight.untyped_storage().data_ptr() - ) - ): - - copy_checkpoint_bool = False - logging.info("Removing lm_head from checkpoint") - del model.lm_head.weight - - if hasattr(model, "lm_head.weight"): - logging.warning("Failed to delete lm_head.weight from model") - - logging.info("Saving checkpoint to %s", original_output_dir) - model.save_pretrained(original_output_dir) - # save tokenizer with model - tokenizer.save_pretrained(original_output_dir) - - # copy last checkpoint into mounted output dir - if copy_checkpoint_bool: - logging.info( - "Copying last checkpoint %s into output dir %s", - last_checkpoint_dir, - original_output_dir, - ) - copy_checkpoint(last_checkpoint_path, original_output_dir) - except Exception as e: # pylint: disable=broad-except - logging.error(traceback.format_exc()) - write_termination_log( - f"Exception encountered writing output model to storage: {e}" - ) - sys.exit(INTERNAL_ERROR_EXIT_CODE) + logging.info("Removing lm_head from checkpoint") + del model.lm_head.weight - # copy over any loss logs - try: - train_logs_filepath = os.path.join( - tempdir, - FileLoggingTrackerConfig.training_logs_filename, - ) - if os.path.exists(train_logs_filepath): - shutil.copy(train_logs_filepath, original_output_dir) - - # The .complete file will signal to users that we are finished copying - # files over - if os.path.exists(original_output_dir): - Path(os.path.join(original_output_dir, ".complete")).touch() - except Exception as e: # pylint: disable=broad-except - logging.error(traceback.format_exc()) - write_termination_log( - f"Exception encountered in capturing training logs: {e}" - ) - sys.exit(INTERNAL_ERROR_EXIT_CODE) + if hasattr(model, "lm_head.weight"): + logging.warning("Failed to delete lm_head.weight from model") + + logging.info("Saving checkpoint to %s", output_dir) + model.save_pretrained(checkpoint_dir) + # save tokenizer with model + tokenizer.save_pretrained(checkpoint_dir) + + except Exception as e: # pylint: disable=broad-except + logging.error(traceback.format_exc()) + write_termination_log(f"Exception encountered removing lm_head from model: {e}") + sys.exit(INTERNAL_ERROR_EXIT_CODE) + + # The .complete file will signal to users that we are finished copying + # files over + if os.path.exists(output_dir): + Path(os.path.join(output_dir, ".complete")).touch() return 0 diff --git a/docs/experiment-tracking.md b/docs/experiment-tracking.md new file mode 100644 index 000000000..edc4e5978 --- /dev/null +++ b/docs/experiment-tracking.md @@ -0,0 +1,133 @@ +# Experiment Tracker + +Experiment tracking is an optional feature of this repo. We have introduced experiment tracking to help in systematically recording hyperparameters, model configurations, and results from each experiment automatically and with the help of third party trackers like [Aimstack](https://aimstack.io). + +Tracking can be enabled by passing a config in the [Training Arguments](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/configs.py#L131) +with the name of the enabled trackers passed as a list. + +``` +from tuning import sft_trainer + +training_args = TrainingArguments( + ..., + trackers = ["aim", "file_logger"] +) + +sft_trainer.train(train_args=training_args,...) +``` + +For each of the requested trackers the code expects you to pass a config to the `sft_trainer.train` function which can be specified through `tracker_conifgs` argument [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/sft_trainer.py#L78) details of which are present below. + + + + +## Tracker Configurations + +## File Logging Tracker + +[File Logger](../tuning/trackers/filelogging_tracker.py) is an inbuilt tracker which can be used to dump loss at every log interval to a file. + +Currently `File Logger` is enabled by default and will dump loss at every log interval of a training to a default file path specified [here](../tuning/config/tracker_configs.py) inside the output folder passed during training. + +To override the location of file logger please pass an instance of the [FileLoggingTrackerConfig](../tuning/config/tracker_configs.py) to `tracker_configs` argument. + +``` +from tuning import sft_trainer +from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory + +training_args = TrainingArguments( + ..., + trackers = ["file_logger"] +) + + +logs_file = "new_train_logs.jsonl" + +tracker_configs = TrackerConfigFactory( + file_logger_config=FileLoggingTrackerConfig( + training_logs_filename=logs_file + ) + ) + +sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs, ...) +``` + +Currently File Logging tacker supports only one argument and this file will be placed inside the `train_args.output` folder. + +## Aimstack Tracker + +To enable [Aim](https://aimstack.io) users need to pass `"aim"` as the requested tracker as part of the [training argument](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/configs.py#L131). + + +When using Aimstack, users need to specify additional arguments which specify where the Aimstack database is present and what experiment name to use +for tracking the training. + +Aimstack supports either a local (`filesystem path`) based db location or a remote (`aim_server:port`) based database location. + +See Aim [documentation](https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html) for more details. + +After [initialising a repo](https://aimstack.readthedocs.io/en/latest/quick_start/setup.html#initializing-aim-repository), users can specify the location of the +repo either local or remote. + +For a local aim database where `aim_repo` should point to the path of where the initialized Aimstack repo is present, + +``` +from tuning import sft_trainer +from tuning.config.tracker_configs import AimConfig, TrackerConfigFactory + +training_args = TrainingArguments( + ..., + trackers = ["aim"], +) + +tracker_configs = TrackerConfigFactory( + aim_config=AimConfig( + experiment="experiment-name", + aim_repo= + ) + ) + +sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....) +``` + + or, for a remote server where aimstack database is running at `aim://aim_remote_server_ip:aim_remote_server_port` + +``` +from tuning import sft_trainer +from tuning.config.tracker_configs import AimConfig, TrackerConfigFactory + +training_args = TrainingArguments( + ..., + trackers = ["aim"], +) + +tracker_configs = TrackerConfigFactory( + aim_config=AimConfig( + experiment="experiment-name", + aim_remote_server_ip=, + aim_remote_server_port= + ) + ) + +sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....) +``` + +The code expects either the `local` or `remote` repo to be specified and will result in a `ValueError` otherwise. +See [AimConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/tracker_configs.py#L25) for more details. + + +## Running the code via command line `tuning/sft_trainer::main` function + +If running the code via main function of [sft_trainer.py](../tuning/sft_trainer.py) the arguments to enable and customise trackers can be passed via commandline. + +To enable tracking please pass + +``` +--tracker +``` + +To further customise tracking you can specify additional arguments needed by the tracker like + +``` +--tracker aim --aim_repo --experiment +``` \ No newline at end of file diff --git a/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml new file mode 100644 index 000000000..53c8a1777 --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_stop_on_training_loss_below_threshold + triggers: + - on_step_end + rule: len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] and training_loss_window["training_loss"]["loss"][0] < 2.2 and training_loss_window["training_loss"]["epoch"][0] > 2 + config: + trigger_log_level: warning + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/examples/trainercontroller_configs/log_controller.yaml b/examples/trainercontroller_configs/log_controller.yaml new file mode 100644 index 000000000..0becdc7e2 --- /dev/null +++ b/examples/trainercontroller_configs/log_controller.yaml @@ -0,0 +1,16 @@ +controller_metrics: + - name: trainer_state + class: TrainingState +operations: + - name: logcontrolstep + class: LogControl + arguments: + log_format: 'This is a test log format [{event_name}] => {trainer_state}' + log_level: warning +controllers: + - name: log-controller-step + triggers: + - on_step_end + rule: 'True' + operations: + - logcontrolstep.should_log \ No newline at end of file diff --git a/examples/trainercontroller_configs/loss.yaml b/examples/trainercontroller_configs/loss.yaml index d7d0baa2b..c4322a6b4 100644 --- a/examples/trainercontroller_configs/loss.yaml +++ b/examples/trainercontroller_configs/loss.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 1.0 + rule: training_loss["loss"] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3438ecfb4..e31192470 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<24", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"] +dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"] flash-attn = ["flash-attn>=2.5.3,<3.0"] aim = ["aim>=3.19.0,<4.0"] fms-accel = ["fms-acceleration>=0.1"] diff --git a/scripts/run_evaluation.py b/scripts/run_evaluation.py index dde162bb8..b488994f0 100644 --- a/scripts/run_evaluation.py +++ b/scripts/run_evaluation.py @@ -62,6 +62,10 @@ def parse_and_validate_args(): help="Whether to load the model using Flash Attention 2", action="store_true", ) + parser.add_argument( + "--base_model_name_or_path", + help="Base model for adapter", + ) parsed_args = parser.parse_args() print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}") @@ -446,7 +450,20 @@ def export_experiment_info( if __name__ == "__main__": args = parse_and_validate_args() - tuned_model = TunedCausalLM.load(args.model, use_flash_attn=args.use_flash_attn) + + base_model_name_or_path = args.base_model_name_or_path + if not base_model_name_or_path: + adapter_config_path = os.path.join(args.model, "adapter_config.json") + if os.path.exists(adapter_config_path): + with open(adapter_config_path, "r", encoding="utf-8") as config_file: + adapter_config = json.load(config_file) + base_model_name_or_path = adapter_config.get("base_model_name_or_path") + + tuned_model = TunedCausalLM.load( + args.model, + use_flash_attn=args.use_flash_attn, + base_model_name_or_path=base_model_name_or_path, + ) eval_data = datasets.load_dataset( "json", data_files=args.data_path, split=args.split ) diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py index 824b7125c..927af3165 100644 --- a/tests/build/test_launch_script.py +++ b/tests/build/test_launch_script.py @@ -25,12 +25,13 @@ # First Party from build.accelerate_launch import main -from build.utils import serialize_args +from build.utils import serialize_args, get_highest_checkpoint from tests.data import TWITTER_COMPLAINTS_DATA from tuning.utils.error_logging import ( USER_ERROR_EXIT_CODE, INTERNAL_ERROR_EXIT_CODE, ) +from tuning.config.tracker_configs import FileLoggingTrackerConfig SCRIPT = "tuning/sft_trainer.py" MODEL_NAME = "Maykeye/TinyLLama-v0" @@ -61,7 +62,6 @@ "prompt_tuning_init": "RANDOM", "num_virtual_tokens": 8, "prompt_tuning_init_text": "hello", - "tokenizer_name_or_path": MODEL_NAME, "save_strategy": "epoch", "output_dir": "tmp", }, @@ -98,11 +98,9 @@ def test_successful_ft(): os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args assert main() == 0 - # check termination log and .complete files - assert os.path.exists(tempdir + "/termination-log") is False - assert os.path.exists(os.path.join(tempdir, ".complete")) is True - assert os.path.exists(tempdir + "/adapter_config.json") is False - assert len(glob.glob(f"{tempdir}/model*.safetensors")) > 0 + _validate_termination_files_when_tuning_succeeds(tempdir) + checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint, "ft") def test_successful_pt(): @@ -114,11 +112,9 @@ def test_successful_pt(): os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args assert main() == 0 - # check termination log and .complete files - assert os.path.exists(tempdir + "/termination-log") is False - assert os.path.exists(os.path.join(tempdir, ".complete")) is True - assert os.path.exists(tempdir + "/adapter_model.safetensors") is True - assert os.path.exists(tempdir + "/adapter_config.json") is True + _validate_termination_files_when_tuning_succeeds(tempdir) + checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint, "pt") def test_successful_lora(): @@ -130,11 +126,92 @@ def test_successful_lora(): os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args assert main() == 0 - # check termination log and .complete files - assert os.path.exists(tempdir + "/termination-log") is False - assert os.path.exists(os.path.join(tempdir, ".complete")) is True - assert os.path.exists(tempdir + "/adapter_model.safetensors") is True - assert os.path.exists(tempdir + "/adapter_config.json") is True + _validate_termination_files_when_tuning_succeeds(tempdir) + checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint, "lora") + + +def test_lora_save_model_dir_separate_dirs(): + """Run LoRA tuning with separate save_model_dir and output_dir. + Verify model saved to save_model_dir and checkpoints saved to + output_dir. + """ + with tempfile.TemporaryDirectory() as tempdir: + output_dir = os.path.join(tempdir, "output_dir") + save_model_dir = os.path.join(tempdir, "save_model_dir") + setup_env(tempdir) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{ + "output_dir": output_dir, + "save_model_dir": save_model_dir, + "save_total_limit": 1, + }, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + _validate_termination_files_when_tuning_succeeds(output_dir) + _validate_training_output(save_model_dir, "lora") + + assert len(os.listdir(output_dir)) == 3 + checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*")) + assert len(checkpoints) == 1 + + +def test_lora_save_model_dir_same_dir_as_output_dir(): + """Run LoRA tuning with same save_model_dir and output_dir. + Verify checkpoints, logs, and model saved to path. + """ + with tempfile.TemporaryDirectory() as tempdir: + setup_env(tempdir) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{ + "output_dir": tempdir, + "save_model_dir": tempdir, + "gradient_accumulation_steps": 1, + }, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + # check logs, checkpoint dir, and model exists in path + _validate_termination_files_when_tuning_succeeds(tempdir) + # check that model exists in output_dir and checkpoint dir + _validate_training_output(tempdir, "lora") + checkpoint_path = os.path.join(tempdir, get_highest_checkpoint(tempdir)) + _validate_training_output(checkpoint_path, "lora") + + # number of checkpoints should equal number of epochs + checkpoints = glob.glob(os.path.join(tempdir, "checkpoint-*")) + assert len(checkpoints) == TRAIN_KWARGS["num_train_epochs"] + + +def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no(): + """Run LoRA tuning with same save_model_dir and output_dir and + save_strategy=no. Verify no checkpoints created, only + logs and final model. + """ + with tempfile.TemporaryDirectory() as tempdir: + setup_env(tempdir) + TRAIN_KWARGS = { + **BASE_LORA_KWARGS, + **{"output_dir": tempdir, "save_model_dir": tempdir, "save_strategy": "no"}, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + assert main() == 0 + # check that model and logs exists in output_dir + _validate_termination_files_when_tuning_succeeds(tempdir) + _validate_training_output(tempdir, "lora") + + # no checkpoints should be created + checkpoints = glob.glob(os.path.join(tempdir, "checkpoint-*")) + assert len(checkpoints) == 0 def test_bad_script_path(): @@ -213,6 +290,27 @@ def test_config_parsing_error(): assert os.stat(tempdir + "/termination-log").st_size > 0 +def _validate_termination_files_when_tuning_succeeds(base_dir): + # check termination log and .complete files + assert os.path.exists(os.path.join(base_dir, "/termination-log")) is False + assert os.path.exists(os.path.join(base_dir, ".complete")) is True + assert ( + os.path.exists( + os.path.join(base_dir, FileLoggingTrackerConfig.training_logs_filename) + ) + is True + ) + + +def _validate_training_output(base_dir, tuning_technique): + if tuning_technique == "ft": + assert len(glob.glob(f"{base_dir}/model*.safetensors")) > 0 + assert os.path.exists(base_dir + "/adapter_config.json") is False + else: + assert os.path.exists(base_dir + "/adapter_config.json") is True + assert os.path.exists(base_dir + "/adapter_model.safetensors") is True + + def test_cleanup(): # This runs to unset env variables that could disrupt other tests cleanup_env() diff --git a/tests/data/__init__.py b/tests/data/__init__.py index cf88ece96..c93187222 100644 --- a/tests/data/__init__.py +++ b/tests/data/__init__.py @@ -19,11 +19,14 @@ ### Constants used for data DATA_DIR = os.path.join(os.path.dirname(__file__)) -TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json") +TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl") TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT = os.path.join( - DATA_DIR, "twitter_complaints_input_output.json" + DATA_DIR, "twitter_complaints_input_output.jsonl" ) TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json") +TWITTER_COMPLAINTS_TOKENIZED = os.path.join( + DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl" +) EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json") MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json") MODEL_NAME = "Maykeye/TinyLLama-v0" diff --git a/tests/data/trainercontroller/__init__.py b/tests/data/trainercontroller/__init__.py index aaaeabe93..18035f102 100644 --- a/tests/data/trainercontroller/__init__.py +++ b/tests/data/trainercontroller/__init__.py @@ -77,3 +77,5 @@ TRAINER_CONFIG_TEST_THRESHOLDED_TRAINING_LOSS_YAML = os.path.join( _DATA_DIR, "thresholded-training-loss.yaml" ) +TRAINER_CONFIG_TEST_ON_SAVE_YAML = os.path.join(_DATA_DIR, "on-save.yaml") +TRAINER_CONFIG_LOG_CONTROLLER_YAML = os.path.join(_DATA_DIR, "log_controller.yaml") diff --git a/tests/data/trainercontroller/log_controller.yaml b/tests/data/trainercontroller/log_controller.yaml new file mode 100644 index 000000000..0becdc7e2 --- /dev/null +++ b/tests/data/trainercontroller/log_controller.yaml @@ -0,0 +1,16 @@ +controller_metrics: + - name: trainer_state + class: TrainingState +operations: + - name: logcontrolstep + class: LogControl + arguments: + log_format: 'This is a test log format [{event_name}] => {trainer_state}' + log_level: warning +controllers: + - name: log-controller-step + triggers: + - on_step_end + rule: 'True' + operations: + - logcontrolstep.should_log \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation.yaml b/tests/data/trainercontroller/loss_custom_operation.yaml index 603459234..3ec952a85 100644 --- a/tests/data/trainercontroller/loss_custom_operation.yaml +++ b/tests/data/trainercontroller/loss_custom_operation.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss operations: - name: custom_operation @@ -8,6 +8,6 @@ controllers: - name: loss_controller_custom_operation triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - custom_operation.should_perform_action_xyz \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml index 3dac47cb2..e0d3a71d3 100644 --- a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml +++ b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss operations: - name: custom_operation @@ -8,6 +8,6 @@ controllers: - name: loss_controller_custom_operation_invalid_action triggers: - on_log - rule: loss < 1.0 + rule: training_loss["loss"] < 1.0 operations: - custom_operation.should_ \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_metric.yaml b/tests/data/trainercontroller/loss_invalid_metric.yaml index 4d94878aa..8491175b0 100644 --- a/tests/data/trainercontroller/loss_invalid_metric.yaml +++ b/tests/data/trainercontroller/loss_invalid_metric.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: MissingMetricClass controllers: - name: loss_controller_invalid_metric triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation.yaml b/tests/data/trainercontroller/loss_invalid_operation.yaml index f904e27d9..769c9441a 100644 --- a/tests/data/trainercontroller/loss_invalid_operation.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_operation triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - missingop.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation_action.yaml b/tests/data/trainercontroller/loss_invalid_operation_action.yaml index 3015516ef..7d8a17ad0 100644 --- a/tests/data/trainercontroller/loss_invalid_operation_action.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation_action.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_operation_action triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.missingaction \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_trigger.yaml b/tests/data/trainercontroller/loss_invalid_trigger.yaml index 382ad7783..38abe7ed9 100644 --- a/tests/data/trainercontroller/loss_invalid_trigger.yaml +++ b/tests/data/trainercontroller/loss_invalid_trigger.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_trigger triggers: - log_it_all_incorrect_trigger_name - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold.yaml b/tests/data/trainercontroller/loss_on_threshold.yaml index d7d0baa2b..24891e8ed 100644 --- a/tests/data/trainercontroller/loss_on_threshold.yaml +++ b/tests/data/trainercontroller/loss_on_threshold.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml index 45e2a3eea..cb9bcf957 100644 --- a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml +++ b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml @@ -1,12 +1,12 @@ controller_metrics: - - name: state + - name: trainer_state class: TrainingState - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 2 and state["epoch"] >= 0.5 + rule: training_loss['loss'] < 2 and trainer_state["epoch"] >= 0.5 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_unavailable_metric.yaml b/tests/data/trainercontroller/loss_unavailable_metric.yaml index 055b93cf3..564184290 100644 --- a/tests/data/trainercontroller/loss_unavailable_metric.yaml +++ b/tests/data/trainercontroller/loss_unavailable_metric.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_unavailable_metric triggers: - on_step_end - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml index 01495f106..bf8648e93 100644 --- a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml +++ b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_os_rule diff --git a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml index 6d5c65328..e2cbb26de 100644 --- a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_input_rule diff --git a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml index badcf940a..5ee4bc224 100644 --- a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_os_rule diff --git a/tests/data/trainercontroller/on-save.yaml b/tests/data/trainercontroller/on-save.yaml new file mode 100644 index 000000000..225cba1cc --- /dev/null +++ b/tests/data/trainercontroller/on-save.yaml @@ -0,0 +1,10 @@ +controller_metrics: + - name: trainer_state + class: TrainingState +controllers: + - name: stop_on_training_loss_on_save + triggers: + - on_save + rule: trainer_state["epoch"] >= 0.5 + operations: + - hfcontrols.should_training_stop diff --git a/tests/data/twitter_complaints_input_output.json b/tests/data/twitter_complaints_input_output.jsonl similarity index 100% rename from tests/data/twitter_complaints_input_output.json rename to tests/data/twitter_complaints_input_output.jsonl diff --git a/tests/data/twitter_complaints_small.json b/tests/data/twitter_complaints_small.jsonl similarity index 100% rename from tests/data/twitter_complaints_small.json rename to tests/data/twitter_complaints_small.jsonl diff --git a/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl b/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl new file mode 100644 index 000000000..1d56770e3 --- /dev/null +++ b/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl @@ -0,0 +1,10 @@ +{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint","input_ids":[1,16121,9211,31871,1662,31866,31856,7416,17632,369,1398,433,322,629,712,1784,13,13,8458,31922,21597,31871,697,9566],"labels":[1,16121,9211,31871,1662,31866,31856,7416,17632,369,1398,433,322,629,712,1784,13,13,8458,31922,21597,31871,697,9566]} +{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint","input_ids":[1,16121,9211,31871,1662,31892,1260,31825,11273,503,31857,632,5284,365,329,553,1280,31905,960,365,6194,289,11025,31844,365,473,987,12207,4218,389,31822,31853,31854,31886,31852,31852,31854,11300,31847,3873,1507,31843,13,13,8458,31922,21597,31871,697,9566],"labels":[1,16121,9211,31871,1662,31892,1260,31825,11273,503,31857,632,5284,365,329,553,1280,31905,960,365,6194,289,11025,31844,365,473,987,12207,4218,389,31822,31853,31854,31886,31852,31852,31854,11300,31847,3873,1507,31843,13,13,8458,31922,21597,31871,697,9566]} +{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint","input_ids":[1,16121,9211,31871,960,312,473,31876,31824,685,629,31822,31878,4449,5861,287,1662,1299,1574,1590,31833,263,1360,1299,1574,289,623,31822,31824,16346,312,31876,31836,994,277,3560,567,31843,672,322,260,29458,288,629,14881,31843,2628,1423,1662,31858,601,1662,31858,601,8378,13,13,8458,31922,21597,31871,9566],"labels":[1,16121,9211,31871,960,312,473,31876,31824,685,629,31822,31878,4449,5861,287,1662,1299,1574,1590,31833,263,1360,1299,1574,289,623,31822,31824,16346,312,31876,31836,994,277,3560,567,31843,672,322,260,29458,288,629,14881,31843,2628,1423,1662,31858,601,1662,31858,601,8378,13,13,8458,31922,21597,31871,9566]} +{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint","input_ids":[1,16121,9211,31871,1662,7766,1078,8123,17561,308,3456,1833,975,10849,291,4372,15379,504,10011,2368,1512,31822,31855,31852,31852,1243,31843,3007,322,433,31843,13,13,8458,31922,21597,31871,9566],"labels":[1,16121,9211,31871,1662,7766,1078,8123,17561,308,3456,1833,975,10849,291,4372,15379,504,10011,2368,1512,31822,31855,31852,31852,1243,31843,3007,322,433,31843,13,13,8458,31922,21597,31871,9566]} +{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint","input_ids":[1,16121,9211,31871,12371,2208,26657,31844,560,14138,31843,21994,1257,24870,496,31829,8198,19057,13,13,8458,31922,21597,31871,697,9566],"labels":[1,16121,9211,31871,12371,2208,26657,31844,560,14138,31843,21994,1257,24870,496,31829,8198,19057,13,13,8458,31922,21597,31871,697,9566]} +{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint","input_ids":[1,16121,9211,31871,1662,31836,651,307,395,13094,672,1467,701,333,515,31844,504,1097,2266,282,305,781,31902,21626,31822,31824,5540,397,560,5253,662,365,31876,263,4985,31854,8903,16801,291,612,31925,2011,1129,31824,31843,1358,31873,19919,31824,31865,31829,469,2131,31874,13,13,8458,31922,21597,31871,697,9566],"labels":[1,16121,9211,31871,1662,31836,651,307,395,13094,672,1467,701,333,515,31844,504,1097,2266,282,305,781,31902,21626,31822,31824,5540,397,560,5253,662,365,31876,263,4985,31854,8903,16801,291,612,31925,2011,1129,31824,31843,1358,31873,19919,31824,31865,31829,469,2131,31874,13,13,8458,31922,21597,31871,697,9566]} +{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint","input_ids":[1,16121,9211,31871,1662,31900,307,31837,473,382,685,266,3195,17532,329,260,1173,9363,352,1671,1881,646,619,31822,31882,5556,504,2091,31822,31882,31843,31855,31861,405,499,382,863,260,31822,31878,4449,2540,2042,31902,13,13,8458,31922,21597,31871,697,9566],"labels":[1,16121,9211,31871,1662,31900,307,31837,473,382,685,266,3195,17532,329,260,1173,9363,352,1671,1881,646,619,31822,31882,5556,504,2091,31822,31882,31843,31855,31861,405,499,382,863,260,31822,31878,4449,2540,2042,31902,13,13,8458,31922,21597,31871,697,9566]} +{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint","input_ids":[1,16121,9211,31871,1662,14390,16373,337,312,435,697,1579,291,266,3925,322,1434,291,3877,31843,1456,365,499,1419,562,433,31902,13,13,8458,31922,21597,31871,9566],"labels":[1,16121,9211,31871,1662,14390,16373,337,312,435,697,1579,291,266,3925,322,1434,291,3877,31843,1456,365,499,1419,562,433,31902,13,13,8458,31922,21597,31871,9566]} +{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint","input_ids":[1,16121,9211,31871,7265,7550,389,1662,31856,2226,11596,27771,898,31843,3259,647,312,498,288,635,31844,518,3822,397,2168,28910,31873,13627,4107,1708,31843,312,31876,608,1090,629,10279,289,1662,29966,31831,5605,13,13,8458,31922,21597,31871,9566],"labels":[1,16121,9211,31871,7265,7550,389,1662,31856,2226,11596,27771,898,31843,3259,647,312,498,288,635,31844,518,3822,397,2168,28910,31873,13627,4107,1708,31843,312,31876,608,1090,629,10279,289,1662,29966,31831,5605,13,13,8458,31922,21597,31871,9566]} +{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint","input_ids":[1,16121,9211,31871,1662,31884,1450,7064,31847,6538,30894,4472,289,362,828,31843,864,685,541,9932,843,584,18694,31986,13,13,8458,31922,21597,31871,697,9566],"labels":[1,16121,9211,31871,1662,31884,1450,7064,31847,6538,30894,4472,289,362,828,31843,864,685,541,9932,843,584,18694,31986,13,13,8458,31922,21597,31871,697,9566]} diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 7c96ccceb..0264d3b3b 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -29,18 +29,22 @@ import transformers # First Party +from build.utils import serialize_args from scripts.run_inference import TunedCausalLM from tests.data import ( EMPTY_DATA, MALFORMATTED_DATA, MODEL_NAME, TWITTER_COMPLAINTS_DATA, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, TWITTER_COMPLAINTS_JSON_FORMAT, + TWITTER_COMPLAINTS_TOKENIZED, ) # Local from tuning import sft_trainer from tuning.config import configs, peft_config +from tuning.config.tracker_configs import FileLoggingTrackerConfig MODEL_ARGS = configs.ModelArguments( model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" @@ -157,14 +161,13 @@ def test_parse_arguments_peft_method(job_config): parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) + assert not tune_config.target_modules + assert "target_modules" not in job_config_lora ############################# Prompt Tuning Tests ############################# -@pytest.mark.skip( - reason="currently inference doesn't work with transformer version 4.42.4" -) def test_run_causallm_pt_and_inference(): """Check if we can bootstrap and peft tune causallm models""" with tempfile.TemporaryDirectory() as tempdir: @@ -177,11 +180,9 @@ def test_run_causallm_pt_and_inference(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. + _validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) # Load the model @@ -195,9 +196,6 @@ def test_run_causallm_pt_and_inference(): assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference -@pytest.mark.skip( - reason="currently inference doesn't work with transformer version 4.42.4" -) def test_run_causallm_pt_and_inference_with_formatting_data(): """Check if we can bootstrap and peft tune causallm models This test needs the trainer to format data to a single sequence internally. @@ -218,11 +216,8 @@ def test_run_causallm_pt_and_inference_with_formatting_data(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. _validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) # Load the model @@ -236,9 +231,6 @@ def test_run_causallm_pt_and_inference_with_formatting_data(): assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference -@pytest.mark.skip( - reason="currently inference doesn't work with transformer version 4.42.4" -) def test_run_causallm_pt_and_inference_JSON_file_formatter(): """Check if we can bootstrap and peft tune causallm models with JSON train file format""" with tempfile.TemporaryDirectory() as tempdir: @@ -257,11 +249,8 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. _validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) # Load the model @@ -292,11 +281,8 @@ def test_run_causallm_pt_init_text(): _validate_training(tempdir) checkpoint_path = _get_checkpoint_path(tempdir) adapter_config = _get_adapter_config(checkpoint_path) - # tokenizer_name_or_path from model arguments is passed - # while preparing the prompt tuning config which - # defaults to model_name_or_path if not explicitly set. _validate_adapter_config( - adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path + adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path ) @@ -356,6 +342,20 @@ def test_run_causallm_pt_with_validation_data_formatting(): _validate_training(tempdir, check_eval=True) +def test_run_causallm_pt_with_custom_tokenizer(): + """Check if we fail when custom tokenizer not having pad token is used in prompt tuning""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + model_args = copy.deepcopy(MODEL_ARGS) + model_args.tokenizer_name_or_path = model_args.model_name_or_path + train_args.output_dir = tempdir + train_args.eval_strategy = "epoch" + data_args = copy.deepcopy(DATA_ARGS) + data_args.validation_data_path = TWITTER_COMPLAINTS_DATA + with pytest.raises(ValueError): + sft_trainer.train(model_args, data_args, train_args, PEFT_PT_ARGS) + + ############################# Lora Tests ############################# target_modules_val_map = [ @@ -407,12 +407,103 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected): assert "Simply put, the theory of relativity states that" in output_inference +def test_successful_lora_target_modules_default_from_main(): + """Check that if target_modules is not set, or set to None via JSON, the + default value by model type will be using in LoRA tuning. + The correct default target modules will be used for model type llama + and will exist in the resulting adapter_config.json. + https://github.com/huggingface/peft/blob/7b1c08d2b5e13d3c99b7d6ee83eab90e1216d4ba/ + src/peft/tuners/lora/model.py#L432 + """ + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **MODEL_ARGS.__dict__, + **TRAIN_ARGS.__dict__, + **DATA_ARGS.__dict__, + **PEFT_LORA_ARGS.__dict__, + **{"peft_method": "lora", "output_dir": tempdir}, + } + serialized_args = serialize_args(TRAIN_KWARGS) + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args + + sft_trainer.main() + + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + adapter_config = _get_adapter_config(checkpoint_path) + _validate_adapter_config(adapter_config, "LORA") + + assert ( + "target_modules" in adapter_config + ), "target_modules not found in adapter_config.json." + + assert set(adapter_config.get("target_modules")) == { + "q_proj", + "v_proj", + }, "target_modules are not set to the default values." + + ############################# Finetuning Tests ############################# def test_run_causallm_ft_and_inference(): - """Check if we can bootstrap and finetune tune causallm models""" + """Check if we can bootstrap and finetune causallm models""" with tempfile.TemporaryDirectory() as tempdir: _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir) - _test_run_inference(tempdir=tempdir) + _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) + + +def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no(): + """Check if we can bootstrap and finetune causallm model with save_model_dir + and save_strategy=no. Verify no checkpoints created and can save model. + """ + with tempfile.TemporaryDirectory() as tempdir: + save_model_args = copy.deepcopy(TRAIN_ARGS) + save_model_args.save_strategy = "no" + save_model_args.output_dir = tempdir + + trainer = sft_trainer.train(MODEL_ARGS, DATA_ARGS, save_model_args, None) + logs_path = os.path.join( + tempdir, FileLoggingTrackerConfig.training_logs_filename + ) + _validate_logfile(logs_path) + # validate that no checkpoints created + assert not any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) + + sft_trainer.save(tempdir, trainer) + assert any(x.endswith(".safetensors") for x in os.listdir(tempdir)) + _test_run_inference(checkpoint_path=tempdir) + + +def test_run_causallm_ft_pretokenized(): + """Check if we can bootstrap and finetune causallm models using pretokenized data""" + with tempfile.TemporaryDirectory() as tempdir: + data_formatting_args = copy.deepcopy(DATA_ARGS) + + # below args not needed for pretokenized data + data_formatting_args.data_formatter_template = None + data_formatting_args.dataset_text_field = None + data_formatting_args.response_template = None + + # update the training data path to tokenized data + data_formatting_args.training_data_path = TWITTER_COMPLAINTS_TOKENIZED + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args) + + # validate full ft configs + _validate_training(tempdir) + checkpoint_path = _get_checkpoint_path(tempdir) + + # Load the model + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) + + # Run inference on the text + output_inference = loaded_model.run( + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 + ) + assert len(output_inference) > 0 + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference ############################# Helper functions ############################# @@ -425,9 +516,7 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): _validate_training(tempdir) -def _test_run_inference(tempdir): - checkpoint_path = _get_checkpoint_path(tempdir) - +def _test_run_inference(checkpoint_path): # Load the model loaded_model = TunedCausalLM.load(checkpoint_path) @@ -444,12 +533,16 @@ def _validate_training( ): assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) train_logs_file_path = "{}/{}".format(tempdir, train_logs_file) + _validate_logfile(train_logs_file_path, check_eval) + + +def _validate_logfile(log_file_path, check_eval=False): train_log_contents = "" - with open(train_logs_file_path, encoding="utf-8") as f: + with open(log_file_path, encoding="utf-8") as f: train_log_contents = f.read() - assert os.path.exists(train_logs_file_path) is True - assert os.path.getsize(train_logs_file_path) > 0 + assert os.path.exists(log_file_path) is True + assert os.path.getsize(log_file_path) > 0 assert "training_loss" in train_log_contents if check_eval: @@ -724,3 +817,58 @@ def test_run_with_good_experimental_metadata(): additional_callbacks=[TrainerCallback()], exp_metadata=metadata, ) + + +### Tests for pretokenized data +def test_pretokenized_dataset(): + """Ensure that we can provide a pretokenized dataset with input/output format.""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = None + data_args.response_template = None + data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + _validate_training(tempdir) + + +@pytest.mark.parametrize( + "dataset_text_field,response_template", + [ + ("foo", None), + (None, "bar"), + ], +) +def test_pretokenized_dataset_bad_args(dataset_text_field, response_template): + """Ensure that we can't provide only dataset text field / response template for pretok data.""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = dataset_text_field + data_args.response_template = response_template + data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT + # We should raise an error since we should not have a dataset text + # field or a response template if we have pretokenized data + with pytest.raises(ValueError): + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) + + +def test_pretokenized_dataset_wrong_format(): + """Ensure that we fail to generate data if the data is in the wrong format.""" + with tempfile.TemporaryDirectory() as tempdir: + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + data_args = copy.deepcopy(DATA_ARGS) + data_args.dataset_text_field = None + data_args.response_template = None + data_args.training_data_path = TWITTER_COMPLAINTS_DATA + + # It would be best to handle this in a way that is more understandable; we might + # need to directly add validation prior to the dataset generation since datasets + # is essentially swallowing a KeyError here. + with pytest.raises(ValueError): + sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) diff --git a/tests/trackers/test_aim_tracker.py b/tests/trackers/test_aim_tracker.py index f0002e9ff..d2aa301b7 100644 --- a/tests/trackers/test_aim_tracker.py +++ b/tests/trackers/test_aim_tracker.py @@ -30,6 +30,7 @@ DATA_ARGS, MODEL_ARGS, TRAIN_ARGS, + _get_checkpoint_path, _test_run_inference, _validate_training, ) @@ -98,7 +99,7 @@ def test_e2e_run_with_aim_tracker(aimrepo): _validate_training(tempdir) # validate inference - _test_run_inference(tempdir) + _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir)) @pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed") diff --git a/tests/trackers/test_file_logging_tracker.py b/tests/trackers/test_file_logging_tracker.py index 2129e4927..e5e62ab8b 100644 --- a/tests/trackers/test_file_logging_tracker.py +++ b/tests/trackers/test_file_logging_tracker.py @@ -25,6 +25,7 @@ DATA_ARGS, MODEL_ARGS, TRAIN_ARGS, + _get_checkpoint_path, _test_run_causallm_ft, _test_run_inference, _validate_training, @@ -44,7 +45,7 @@ def test_run_with_file_logging_tracker(): train_args.trackers = ["file_logger"] _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir) - _test_run_inference(tempdir=tempdir) + _test_run_inference(_get_checkpoint_path(tempdir)) def test_sample_run_with_file_logger_updated_filename(): diff --git a/tests/trainercontroller/custom_operation.py b/tests/trainercontroller/custom_operation.py index 2c402fa96..522200b49 100644 --- a/tests/trainercontroller/custom_operation.py +++ b/tests/trainercontroller/custom_operation.py @@ -26,13 +26,6 @@ class CustomOperation(Operation): """Implements a custom operation for testing""" - def __init__(self, **_): - """Initializes the custom operation class. - Args: - kwargs: List of arguments (key, value)-pairs - """ - super().__init__() - def should_perform_action_xyz(self, control: TrainerControl, **_): """This method performs a set training stop flag action. diff --git a/tests/trainercontroller/custom_operation_invalid_action.py b/tests/trainercontroller/custom_operation_invalid_action.py index 5c04199d3..6871a64fd 100644 --- a/tests/trainercontroller/custom_operation_invalid_action.py +++ b/tests/trainercontroller/custom_operation_invalid_action.py @@ -26,13 +26,6 @@ class CustomOperationInvalidAction(Operation): """Implements a custom operation for testing""" - def __init__(self, **_): - """Initializes the custom operation class. - Args: - kwargs: List of arguments (key, value)-pairs - """ - super().__init__() - def should_(self, control: TrainerControl, **_): """This method defines an action within an invalid name. diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py index 7f98ace94..ba1a05808 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -138,6 +138,39 @@ def test_thresholded_training_loss(): assert control.should_training_stop is True +def test_thresholded_training_loss_on_save(): + """Tests the thresholded training loss example in + `examples/trainer-controller-configs/on-save.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_TEST_ON_SAVE_YAML) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + # Trigger rule and test the condition + tc_callback.on_save(args=test_data.args, state=test_data.states[2], control=control) + assert control.should_training_stop is True + + +def test_log_controller(caplog): + """Tests the expose metric scenario example in + `examples/trainer-controller-configs/log_controller.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_LOG_CONTROLLER_YAML) + control = TrainerControl(should_log=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + tc_callback.on_step_end( + args=test_data.args, state=test_data.states[2], control=control + ) + assert "This is a test log format" in caplog.text + + def test_non_decreasing_training_loss(): """Tests the non-decreasing training loss example in `examples/trainer-controller-configs/non-decreasing-training-loss.yaml` diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py new file mode 100644 index 000000000..1cbbbaaa0 --- /dev/null +++ b/tests/utils/test_config_utils.py @@ -0,0 +1,234 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +import base64 +import os +import pickle + +# Third Party +from peft import LoraConfig, PromptTuningConfig +import pytest + +# First Party +from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH + +# Local +from tuning.config import peft_config +from tuning.utils import config_utils + + +def test_get_hf_peft_config_returns_None_for_tuning_config_None(): + """Test that when tuning_config is None, the function returns None""" + expected_config = None + assert expected_config == config_utils.get_hf_peft_config("", None, "") + + +def test_get_hf_peft_config_returns_lora_config_correctly(): + """Test that tuning_config fields are passed to LoraConfig correctly, + If not defined, the default values are used + """ + tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3) + + config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "") + assert isinstance(config, LoraConfig) + assert config.task_type == "CAUSAL_LM" + assert config.r == 3 + assert config.lora_alpha == 3 + assert ( + config.lora_dropout == 0.05 + ) # default value from local peft_config.LoraConfig + assert ( + config.target_modules is None + ) # default value from local peft_config.LoraConfig + assert config.init_lora_weights is True # default value from HF peft.LoraConfig + assert ( + config.megatron_core == "megatron.core" + ) # default value from HF peft.LoraConfig + + +def test_get_hf_peft_config_ignores_tokenizer_path_for_lora_config(): + """Test that if tokenizer is given with a LoraConfig, it is ignored""" + tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3) + + config = config_utils.get_hf_peft_config( + task_type="CAUSAL_LM", + tuning_config=tuning_config, + tokenizer_name_or_path="foo/bar/path", + ) + assert isinstance(config, LoraConfig) + assert config.task_type == "CAUSAL_LM" + assert config.r == 3 + assert config.lora_alpha == 3 + assert not hasattr(config, "tokenizer_name_or_path") + + +def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear(): + """Test that when target_modules is ["all-linear"], we convert it to str type "all-linear" """ + tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"]) + + config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "") + assert isinstance(config, LoraConfig) + assert config.target_modules == "all-linear" + + +def test_get_hf_peft_config_returns_pt_config_correctly(): + """Test that the prompt tuning config is set properly for each field + When a value is not defined, the default values are used + """ + tuning_config = peft_config.PromptTuningConfig(num_virtual_tokens=12) + + config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "foo/bar/path") + assert isinstance(config, PromptTuningConfig) + assert config.task_type == "CAUSAL_LM" + assert ( + config.prompt_tuning_init == "TEXT" + ) # default value from local peft_config.PromptTuningConfig + assert config.num_virtual_tokens == 12 + assert ( + config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:" + ) # default value from local peft_config.PromptTuningConfig + assert config.tokenizer_name_or_path == "foo/bar/path" + assert config.num_layers is None # default value from HF peft.PromptTuningConfig + assert ( + config.inference_mode is False + ) # default value from HF peft.PromptTuningConfig + + +def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path(): + """Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT + Reference: + https://github.com/huggingface/peft/blob/main/src/peft/tuners/prompt_tuning/config.py#L73 + """ + + # When prompt_tuning_init is not TEXT, we can pass in None for tokenizer path + tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="RANDOM") + config = config_utils.get_hf_peft_config( + task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None + ) + assert isinstance(config, PromptTuningConfig) + assert config.tokenizer_name_or_path is None + + # When prompt_tuning_init is TEXT, exception is raised if tokenizer path is None + tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="TEXT") + with pytest.raises(ValueError) as err: + config_utils.get_hf_peft_config( + task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None + ) + assert "tokenizer_name_or_path can't be None" in err.value + + +def test_create_tuning_config_for_peft_method_lora(): + """Test that LoraConfig is created for peft_method Lora + and fields are set properly. + If unknown fields are passed, they are ignored + """ + tune_config = config_utils.create_tuning_config("lora", foo="x", r=234) + assert isinstance(tune_config, peft_config.LoraConfig) + assert tune_config.r == 234 + assert tune_config.lora_alpha == 32 + assert tune_config.lora_dropout == 0.05 + assert not hasattr(tune_config, "foo") + + +def test_create_tuning_config_for_peft_method_pt(): + """Test that PromptTuningConfig is created for peft_method pt + and fields are set properly + """ + tune_config = config_utils.create_tuning_config( + "pt", foo="x", prompt_tuning_init="RANDOM" + ) + assert isinstance(tune_config, peft_config.PromptTuningConfig) + assert tune_config.prompt_tuning_init == "RANDOM" + + +def test_create_tuning_config_for_peft_method_none(): + """Test that PromptTuningConfig is created for peft_method "None" or None""" + tune_config = config_utils.create_tuning_config("None") + assert tune_config is None + + tune_config = config_utils.create_tuning_config(None) + assert tune_config is None + + +def test_create_tuning_config_does_not_recognize_any_other_peft_method(): + """Test that PromptTuningConfig is created for peft_method "None" or None, + "lora" or "pt", and no other + """ + with pytest.raises(AssertionError) as err: + config_utils.create_tuning_config("hello", foo="x") + assert err.value == "peft config hello not defined in peft.py" + + +def test_update_config_can_handle_dot_for_nested_field(): + """Test that the function can read dotted field for kwargs fields""" + config = peft_config.LoraConfig(r=5) + assert config.lora_alpha == 32 # default value is 32 + + # update lora_alpha to 98 + kwargs = {"LoraConfig.lora_alpha": 98} + config_utils.update_config(config, **kwargs) + assert config.lora_alpha == 98 + + +def test_update_config_does_nothing_for_unknown_field(): + """Test that the function does not change other config + field values if a kwarg field is unknown + """ + # foobar is an unknown field + config = peft_config.LoraConfig(r=5) + kwargs = {"LoraConfig.foobar": 98} + config_utils.update_config(config, **kwargs) # nothing happens + assert config.r == 5 # did not change r value + assert not hasattr(config, "foobar") + + +def test_update_config_can_handle_multiple_config_updates(): + """Test that the function can handle a tuple of configs""" + config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7)) + kwargs = {"r": 98} + config_utils.update_config(config, **kwargs) + assert config[0].r == 98 + assert config[1].r == 98 + + +def test_get_json_config_can_load_from_path(): + """Test that the function get_json_config can read + the json path from env var SFT_TRAINER_CONFIG_JSON_PATH + """ + if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ: + del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] + os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH + + job_config = config_utils.get_json_config() + assert job_config is not None + assert job_config["model_name_or_path"] == "bigscience/bloom-560m" + + +def test_get_json_config_can_load_from_envvar(): + """Test that the function get_json_config can read + the json path from env var SFT_TRAINER_CONFIG_JSON_ENV_VAR + """ + config_json = {"model_name_or_path": "foobar"} + message_bytes = pickle.dumps(config_json) + base64_bytes = base64.b64encode(message_bytes) + encoded_json = base64_bytes.decode("ascii") + os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = encoded_json + + job_config = config_utils.get_json_config() + assert job_config is not None + assert job_config["model_name_or_path"] == "foobar" diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py new file mode 100644 index 000000000..7b7aa1a2a --- /dev/null +++ b/tests/utils/test_logging.py @@ -0,0 +1,84 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from unittest import mock +import copy +import logging +import os + +# First Party +from tests.test_sft_trainer import TRAIN_ARGS + +# Local +from tuning.utils.logging import set_log_level + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_set_log_level_for_logger_default(): + """ + Ensure that the correct log level is being set for python native logger and + transformers logger when no env var or CLI flag is passed + """ + + train_args = copy.deepcopy(TRAIN_ARGS) + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.WARNING + assert training_args.log_level == "passive" + + +@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True) +def test_set_log_level_for_logger_with_env_var(): + """ + Ensure that the correct log level is being set for python native logger and + transformers logger when env var LOG_LEVEL is used + """ + + train_args_env = copy.deepcopy(TRAIN_ARGS) + training_args, logger = set_log_level(train_args_env) + assert logger.getEffectiveLevel() == logging.INFO + assert training_args.log_level == "info" + + +@mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True) +def test_set_log_level_for_logger_with_set_verbosity_and_cli(): + """ + Ensure that the correct log level is being set for python native logger and + log_level of transformers logger is unchanged when env var TRANSFORMERS_VERBOSITY is used + and CLI flag is passed + """ + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.log_level = "error" + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.ERROR + assert training_args.log_level == "error" + + +@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True) +def test_set_log_level_for_logger_with_env_var_and_cli(): + """ + Ensure that the correct log level is being set for python native logger and + transformers logger when env var LOG_LEVEL is used and CLI flag is passed. + In this case, CLI arg takes precedence over the set env var LOG_LEVEL. + """ + + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.log_level = "error" + training_args, logger = set_log_level(train_args) + assert logger.getEffectiveLevel() == logging.ERROR + assert training_args.log_level == "error" diff --git a/tests/utils/test_preprocessing_utils.py b/tests/utils/test_preprocessing_utils.py index e24cf710a..770d999ce 100644 --- a/tests/utils/test_preprocessing_utils.py +++ b/tests/utils/test_preprocessing_utils.py @@ -11,6 +11,7 @@ MODEL_NAME, TWITTER_COMPLAINTS_DATA, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + TWITTER_COMPLAINTS_TOKENIZED, ) # Local @@ -18,9 +19,10 @@ from tuning.utils.preprocessing_utils import ( combine_sequence, format_dataset, - get_data_trainer_kwargs, + get_data_collator, get_formatted_dataset_with_single_sequence, get_preprocessed_dataset, + is_pretokenized_dataset, load_hf_dataset_from_jsonl_file, validate_data_args, ) @@ -42,6 +44,24 @@ def test_combine_sequence(input_element, output_element, expected_res): assert comb_seq == expected_res +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence_adds_eos(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) + expected_res += tokenizer.eos_token + assert isinstance(comb_seq, str) + assert comb_seq == expected_res + + # Tests for loading the dataset from disk def test_load_hf_dataset_from_jsonl_file(): input_field_name = "Tweet text" @@ -108,80 +128,76 @@ def test_get_preprocessed_dataset(max_sequence_length): assert key_lengths.pop() <= max_sequence_length -# Tests for fetching train args @pytest.mark.parametrize( - "use_validation_data, collator_type, packing", + "packing, response_template, formatted_train_dataset, max_seq_length, expected_collator", [ - (True, None, True), - (False, None, True), - (True, DataCollatorForCompletionOnlyLM, False), - (False, DataCollatorForCompletionOnlyLM, False), + ( + False, + "\n### Label:", + load_hf_dataset_from_jsonl_file( + TWITTER_COMPLAINTS_DATA, + input_field_name="Tweet text", + output_field_name="text_label", + ), + 1024, + DataCollatorForCompletionOnlyLM, + ), + ( + False, + None, + Dataset.from_list( + [ + { + "input_ids": [9437, 29, 210], + "attention_mask": [1, 1, 1], + "labels": [1, 20, 30], + } + ] + ), + 1024, + DataCollatorForSeq2Seq, + ), ], ) -def test_get_trainer_kwargs_with_response_template_and_text_field( - use_validation_data, collator_type, packing +def test_get_data_collator( + packing, + response_template, + formatted_train_dataset, + max_seq_length, + expected_collator, ): - training_data_path = TWITTER_COMPLAINTS_DATA - validation_data_path = training_data_path if use_validation_data else None - # Expected columns in the raw loaded dataset for the twitter data - column_names = set(["Tweet text", "ID", "Label", "text_label", "output"]) - trainer_kwargs = get_data_trainer_kwargs( - training_data_path=training_data_path, - validation_data_path=validation_data_path, - packing=packing, - response_template="\n### Label:", - max_sequence_length=100, - tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME), - dataset_text_field="output", - ) - assert len(trainer_kwargs) == 3 - # If we are packing, we should not have a data collator - if collator_type is None: - assert trainer_kwargs["data_collator"] is None - else: - assert isinstance(trainer_kwargs["data_collator"], collator_type) - - # We should only have a validation dataset if one is present - if validation_data_path is None: - assert trainer_kwargs["eval_dataset"] is None - else: - assert isinstance(trainer_kwargs["eval_dataset"], Dataset) - assert set(trainer_kwargs["eval_dataset"].column_names) == column_names - - assert isinstance(trainer_kwargs["train_dataset"], Dataset) - assert set(trainer_kwargs["train_dataset"].column_names) == column_names - - -@pytest.mark.parametrize("use_validation_data", [True, False]) -def test_get_trainer_kwargs_with_custom_masking(use_validation_data): - training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT - validation_data_path = training_data_path if use_validation_data else None - # Expected columns in the raw loaded dataset for the twitter data - column_names = set(["input_ids", "attention_mask", "labels"]) - trainer_kwargs = get_data_trainer_kwargs( - training_data_path=training_data_path, - validation_data_path=validation_data_path, - packing=False, - response_template=None, - max_sequence_length=100, - tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME), - dataset_text_field=None, + """Ensure that the correct collator type is fetched based on the data args""" + collator = get_data_collator( + packing, + response_template, + AutoTokenizer.from_pretrained(MODEL_NAME), + formatted_train_dataset, + max_seq_length, ) - assert len(trainer_kwargs) == 4 - # If we are packing, we should not have a data collator - assert isinstance(trainer_kwargs["data_collator"], DataCollatorForSeq2Seq) + assert isinstance(collator, expected_collator) - # We should only have a validation dataset if one is present - if validation_data_path is None: - assert trainer_kwargs["eval_dataset"] is None - else: - assert isinstance(trainer_kwargs["eval_dataset"], Dataset) - assert set(trainer_kwargs["eval_dataset"].column_names) == column_names - assert isinstance(trainer_kwargs["train_dataset"], Dataset) - assert set(trainer_kwargs["train_dataset"].column_names) == column_names - # Needed to sidestep TRL validation - assert trainer_kwargs["formatting_func"] is not None +@pytest.mark.parametrize( + "data, result", + [ + (TWITTER_COMPLAINTS_DATA, False), + ( + Dataset.from_list( + [ + { + "input_ids": [9437, 29, 210], + "attention_mask": [1, 1, 1], + "labels": [1, 20, 30], + } + ] + ), + True, + ), + ], +) +def test_is_pretokenized_dat(data, result): + """Ensure that the correct collator type is fetched based on the data args""" + assert is_pretokenized_dataset(data=data) == result # Tests for validating data args @@ -197,6 +213,14 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data): ), False, ), + # data formatter with no response template + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA, + data_formatter_template="### Input: {{input}} \n\n### Response: {{output}}", + ), + False, + ), # response template with no dataset_text_field or formatter ( configs.DataArguments( @@ -205,13 +229,94 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data): ), False, ), + # JSON without input / output for no single sequence arguments + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA, + ), + False, + ), + # pretokenized dataset with dataset_text_field + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + dataset_text_field="output", + ), + False, + ), + # pretokenized dataset with data formatter + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + data_formatter_template="### Input: {{input}} \n\n### Response: {{output}}", + ), + False, + ), + # pretokenized dataset with response template + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + response_template="\n### Label:", + ), + False, + ), + # pretokenized training dataset with validation data not pretokenized + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + validation_data_path=TWITTER_COMPLAINTS_DATA, + ), + False, + ), + # pretokenized data with dataset_text_field and response template + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + response_template="\n### Label:", + dataset_text_field="output", + ), + False, + ), + # pretokenized data with packing to True + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + ), + True, + ), ], ) def test_validate_args(data_args, packing): + """Ensure that respective errors are thrown for incorrect data arguments""" with pytest.raises(ValueError): validate_data_args(data_args, packing) +@pytest.mark.parametrize( + "data_args, packing", + [ + # pretokenized train dataset and no validation dataset passed + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + ), + False, + ), + # pretokenized train and validation datasets + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + validation_data_path=TWITTER_COMPLAINTS_TOKENIZED, + ), + False, + ), + ], +) +def test_validate_args_pretokenized(data_args, packing): + """Ensure that supported data args do not error out when passing pretokenized datasets""" + validate_data_args(data_args, packing) + + @pytest.mark.parametrize( "data_path, dataset_text_field, data_formatter_template", [ @@ -255,12 +360,57 @@ def test_get_formatted_dataset_with_single_sequence( data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", ) ), + # input/output JSON with masking on input + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + ) + ), ], ) def test_format_dataset(data_args): + """Ensure that the train/eval data are properly formatted based on the data args / text field""" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - train_set, eval_set, dataset_text_field = format_dataset(data_args, tokenizer) + train_set, eval_set, dataset_text_field = format_dataset( + data_args, tokenizer, max_seq_length=1024 + ) assert isinstance(train_set, Dataset) assert isinstance(eval_set, Dataset) - assert dataset_text_field in train_set.column_names - assert dataset_text_field in eval_set.column_names + if dataset_text_field is None: + column_names = set(["input_ids", "attention_mask", "labels"]) + assert set(eval_set.column_names) == column_names + assert set(train_set.column_names) == column_names + else: + assert dataset_text_field in train_set.column_names + assert dataset_text_field in eval_set.column_names + + +@pytest.mark.parametrize( + "data_args", + [ + # pretokenized train and validation datasets + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + validation_data_path=TWITTER_COMPLAINTS_TOKENIZED, + ) + ), + # pretokenized train datasets + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED, + ) + ), + ], +) +def test_format_dataset_pretokenized(data_args): + """Ensure that pretokenized datasets are loaded and returned as is""" + train_set, eval_set, _ = format_dataset(data_args, None, max_seq_length=1024) + assert isinstance(train_set, Dataset) + if eval_set: + assert isinstance(eval_set, Dataset) + + assert set(["input_ids", "labels"]).issubset(set(train_set.column_names)) + if eval_set: + assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names)) diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 92fb4f8f8..0db5b518e 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -51,15 +51,15 @@ class ModelArguments: tokenizer_name_or_path: Optional[str] = field( default=None, metadata={ - "help": "Path to custom tokenizer.\ - If not provided it defaults to model_name_or_path" + "help": "Path to custom tokenizer. \ + If not provided it defaults to model_name_or_path \ + and special tokens will be added as needed for specific tokenizer classes. \ + For prompt tuning, if tokenizer_name_or_path provided, special tokens are not added, \ + otherwise, it defaults to model_name_or_path with special tokens for specific \ + tokenizer classes." }, ) - def __post_init__(self): - if not self.tokenizer_name_or_path: - self.tokenizer_name_or_path = self.model_name_or_path - @dataclass class DataArguments: @@ -97,6 +97,7 @@ class DataArguments: @dataclass class TrainingArguments(transformers.TrainingArguments): + # pylint: disable=too-many-instance-attributes cache_dir: Optional[str] = field(default=None) # optim: str = field(default=DEFAULT_OPTIMIZER) max_seq_length: int = field( @@ -119,6 +120,13 @@ class TrainingArguments(transformers.TrainingArguments): 'steps' (save is done every `save_steps`)" }, ) + save_model_dir: str = field( + default=None, + metadata={ + "help": "Directory where tuned model will be saved to \ + using SFTTrainer.save_model()." + }, + ) logging_strategy: str = field( default="epoch", metadata={ @@ -136,6 +144,15 @@ class TrainingArguments(transformers.TrainingArguments): + "Requires additional configs, see tuning.configs/tracker_configs.py" }, ) + log_level: str = field( + default="passive", + metadata={ + "help": "The log level to adopt during training. \ + By default, 'passive' level is set which keeps the \ + current log level for the Transformers library (which will be 'warning` by default) \ + Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'" + }, + ) @dataclass diff --git a/tuning/config/peft_config.py b/tuning/config/peft_config.py index 5230e1652..6bf425159 100644 --- a/tuning/config/peft_config.py +++ b/tuning/config/peft_config.py @@ -46,7 +46,7 @@ class LoraConfig: r: int = 8 lora_alpha: int = 32 target_modules: List[str] = field( - default_factory=lambda: ["q_proj", "v_proj"], + default=None, metadata={ "help": "The names of the modules to apply LORA to. LORA selects modules which either \ completely match or " diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 0e360ad4f..45fea7ca6 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -16,6 +16,8 @@ from typing import Dict, List, Optional, Union import dataclasses import json +import logging +import os import sys import time import traceback @@ -33,7 +35,7 @@ LlamaTokenizerFast, TrainerCallback, ) -from transformers.utils import is_accelerate_available, logging +from transformers.utils import is_accelerate_available from trl import SFTConfig, SFTTrainer import fire import transformers @@ -60,6 +62,7 @@ USER_ERROR_EXIT_CODE, write_termination_log, ) +from tuning.utils.logging import set_log_level from tuning.utils.preprocessing_utils import ( format_dataset, get_data_collator, @@ -111,7 +114,7 @@ def train( fused_lora and fast_kernels must used together (may change in future). \ """ - logger = logging.get_logger("sft_trainer") + train_args, logger = set_log_level(train_args, "sft_trainer_train") # Validate parameters if (not isinstance(train_args.num_train_epochs, (float, int))) or ( @@ -160,7 +163,7 @@ def train( trainer_controller_args.trainer_controller_config_file is not None ): tc_callback = TrainerControllerCallback( - trainer_controller_args.trainer_controller_config_file + trainer_controller_args.trainer_controller_config_file, ) trainer_callbacks.append(tc_callback) @@ -190,32 +193,46 @@ def train( # TODO: Move these to a config as well tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name_or_path, cache_dir=train_args.cache_dir, use_fast=True + ( + model_args.tokenizer_name_or_path + if model_args.tokenizer_name_or_path + else model_args.model_name_or_path + ), + cache_dir=train_args.cache_dir, + use_fast=True, ) # Calculate and save additional metrics to track later. additional_metrics["model_load_time"] = time.time() - model_load_time peft_config = get_hf_peft_config( - task_type, peft_config, model_args.tokenizer_name_or_path + task_type, + peft_config, + ( + model_args.tokenizer_name_or_path + if model_args.tokenizer_name_or_path + else model_args.model_name_or_path + ), ) - # TODO: understand if we need to hardcode these here or just use defaults in model - if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)): - tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - "pad_token": "", - } - ) - elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)): - tokenizer.add_special_tokens( - { - "pad_token": "", - } - ) + # add special tokens only when a custom tokenizer is not passed + if not model_args.tokenizer_name_or_path: + # TODO: understand if we need to hardcode these here or just use defaults in model + if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)): + tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + "pad_token": "", + } + ) + elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)): + tokenizer.add_special_tokens( + { + "pad_token": "", + } + ) max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) logger.info("Max sequence length is %s", max_seq_length) @@ -228,20 +245,22 @@ def train( tokenizer.model_max_length, ) - # TODO: we need to change this, perhaps follow what open instruct does? + # add special tokens only when a custom tokenizer is not passed special_tokens_dict = {} - if tokenizer.pad_token is None: - logger.warning("PAD token set to default, missing in tokenizer") - special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN - if tokenizer.eos_token is None: - logger.warning("EOS token set to default, missing in tokenizer") - special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN - if tokenizer.bos_token is None: - logger.warning("BOS token set to default, missing in tokenizer") - special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN - if tokenizer.unk_token is None: - logger.warning("UNK token set to default, missing in tokenizer") - special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN + if not model_args.tokenizer_name_or_path: + # TODO: we need to change this, perhaps follow what open instruct does? + if tokenizer.pad_token is None: + logger.warning("PAD token set to default, missing in tokenizer") + special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN + if tokenizer.eos_token is None: + logger.warning("EOS token set to default, missing in tokenizer") + special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN + if tokenizer.bos_token is None: + logger.warning("BOS token set to default, missing in tokenizer") + special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN + if tokenizer.unk_token is None: + logger.warning("UNK token set to default, missing in tokenizer") + special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN # TODO: lower priority but understand if resizing impacts inference quality and why its needed. # It makes sense if we manipulate tokenizer that we also save it and provide it to inference. @@ -268,8 +287,14 @@ def train( formatted_train_dataset, formatted_validation_dataset, dataset_text_field, - ) = format_dataset(data_args, tokenizer) - data_collator = get_data_collator(packing, data_args.response_template, tokenizer) + ) = format_dataset(data_args, tokenizer, max_seq_length) + data_collator = get_data_collator( + packing, + data_args.response_template, + tokenizer, + formatted_train_dataset, + max_seq_length, + ) if framework is not None and framework.requires_agumentation: model, (peft_config,) = framework.augmentation( @@ -315,6 +340,7 @@ def train( try: for k, v in additional_metrics.items(): tracker.track(metric=v, name=k, stage="additional_metrics") + if exp_metadata: tracker.set_params(params=exp_metadata, name="experiment_metadata") except ValueError as e: logger.error( @@ -336,6 +362,31 @@ def train( trainer.train() + return trainer + + +def save(path: str, trainer: SFTTrainer, log_level="WARNING"): + """Saves model and tokenizer to given path. + + Args: + path: str + Path to save the model to. + trainer: SFTTrainer + Instance of SFTTrainer used for training to save the model. + """ + logger = logging.getLogger("sft_trainer_save") + # default value from TrainingArguments + if log_level == "passive": + log_level = "WARNING" + + logger.setLevel(log_level) + + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + + logger.info("Saving tuned model to path: %s", path) + trainer.save_model(path) + def get_parser(): """Get the command-line argument parser.""" @@ -456,11 +507,8 @@ def parse_arguments(parser, json_config=None): def main(**kwargs): # pylint: disable=unused-argument - logger = logging.get_logger("__main__") - parser = get_parser() job_config = get_json_config() - logger.debug("Input args parsed: %s", job_config) # accept arguments via command-line or JSON try: ( @@ -475,6 +523,10 @@ def main(**kwargs): # pylint: disable=unused-argument fusedops_kernels_config, exp_metadata, ) = parse_arguments(parser, job_config) + + # Function to set log level for python native logger and transformers training logger + training_args, logger = set_log_level(training_args, __name__) + logger.debug( "Input args parsed: \ model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ @@ -493,7 +545,7 @@ def main(**kwargs): # pylint: disable=unused-argument exp_metadata, ) except Exception as e: # pylint: disable=broad-except - logging.error(traceback.format_exc()) + logger.error(traceback.format_exc()) write_termination_log( f"Exception raised during training. This may be a problem with your input: {e}" ) @@ -520,7 +572,7 @@ def main(**kwargs): # pylint: disable=unused-argument combined_tracker_configs.aim_config = aim_config try: - train( + trainer = train( model_args=model_args, data_args=data_args, train_args=training_args, @@ -557,6 +609,21 @@ def main(**kwargs): # pylint: disable=unused-argument write_termination_log(f"Unhandled exception during training: {e}") sys.exit(INTERNAL_ERROR_EXIT_CODE) + # save model + if training_args.save_model_dir: + try: + save( + path=training_args.save_model_dir, + trainer=trainer, + log_level=training_args.log_level, + ) + except Exception as e: # pylint: disable=broad-except + logger.error(traceback.format_exc()) + write_termination_log( + f"Failed to save model to {training_args.save_model_dir}: {e}" + ) + sys.exit(INTERNAL_ERROR_EXIT_CODE) + if __name__ == "__main__": fire.Fire(main) diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py index bc2f8364d..139fbfc28 100644 --- a/tuning/trackers/aimstack_tracker.py +++ b/tuning/trackers/aimstack_tracker.py @@ -14,11 +14,11 @@ # Standard import json +import logging import os # Third Party from aim.hugging_face import AimCallback # pylint: disable=import-error -from transformers.utils import logging # Local from .tracker import Tracker @@ -63,8 +63,10 @@ def on_init_end(self, args, state, control, **kwargs): Args: For the arguments see reference to transformers.TrainingCallback """ - # pylint: disable=unused-argument - self.setup() # initialize aim's run_hash + super().on_init_end(args, state, control, **kwargs) + + if not self._run: + return # Change default run hash path to output directory if not specified if self.run_id_export_path is None: @@ -97,7 +99,8 @@ def __init__(self, tracker_config: AimConfig): information about the repo or the server and port where aim db is present. """ super().__init__(name="aim", tracker_config=tracker_config) - self.logger = logging.get_logger("aimstack_tracker") + # Get logger with root log level + self.logger = logging.getLogger() def get_hf_callback(self): """Returns the aim.hugging_face.AimCallback object associated with this tracker. @@ -172,7 +175,9 @@ def set_params(self, params, name="extra_params"): Raises: ValueError: the params passed is None or not of type dict """ - if params is None or (not isinstance(params, dict)): + if params is None: + return + if not isinstance(params, dict): raise ValueError( "set_params passed to aimstack should be called with a dict of params" ) diff --git a/tuning/trackers/filelogging_tracker.py b/tuning/trackers/filelogging_tracker.py index 213377d96..133687866 100644 --- a/tuning/trackers/filelogging_tracker.py +++ b/tuning/trackers/filelogging_tracker.py @@ -15,11 +15,11 @@ # Standard from datetime import datetime import json +import logging import os # Third Party from transformers import TrainerCallback -from transformers.utils import logging # Local from .tracker import Tracker @@ -80,7 +80,8 @@ def __init__(self, tracker_config: FileLoggingTrackerConfig): which contains the location of file where logs are recorded. """ super().__init__(name="file_logger", tracker_config=tracker_config) - self.logger = logging.get_logger("file_logging_tracker") + # Get logger with root log level + self.logger = logging.getLogger() def get_hf_callback(self): """Returns the FileLoggingCallback object associated with this tracker. diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index 98771c143..096099306 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -14,18 +14,15 @@ # Standard import dataclasses +import logging # Third Party -from transformers.utils import logging from transformers.utils.import_utils import _is_package_available # Local from .filelogging_tracker import FileLoggingTracker from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory -logger = logging.get_logger("tracker_factory") - - # Information about all registered trackers AIMSTACK_TRACKER = "aim" FILE_LOGGING_TRACKER = "file_logger" @@ -54,9 +51,9 @@ def _register_aim_tracker(): AimTracker = _get_tracker_class(AimStackTracker, AimConfig) REGISTERED_TRACKERS[AIMSTACK_TRACKER] = AimTracker - logger.info("Registered aimstack tracker") + logging.info("Registered aimstack tracker") else: - logger.info( + logging.info( "Not registering Aimstack tracker due to unavailablity of package.\n" "Please install aim if you intend to use it.\n" "\t pip install aim" @@ -72,14 +69,14 @@ def _is_tracker_installed(name): def _register_file_logging_tracker(): FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig) REGISTERED_TRACKERS[FILE_LOGGING_TRACKER] = FileTracker - logger.info("Registered file logging tracker") + logging.info("Registered file logging tracker") # List of Available Trackers # file_logger - Logs loss to a file # aim - Aimstack Tracker def _register_trackers(): - logger.info("Registering trackers") + logging.info("Registering trackers") if AIMSTACK_TRACKER not in REGISTERED_TRACKERS: _register_aim_tracker() if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS: @@ -142,7 +139,7 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory): e = "Requested Tracker {} not found. List trackers available for use is - {} ".format( name, available ) - logger.error(e) + logging.error(e) raise ValueError(e) meta = REGISTERED_TRACKERS[name] diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index b7cd005b5..fad1bbf70 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -59,11 +59,14 @@ CONTROLLER_CONFIG_KEY = "config" CONTROLLER_PATIENCE_CONFIG_KEY = "patience" CONTROLLER_TRIGGERS_KEY = "triggers" +CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL = "trigger_log_level" CONTROLLER_OPERATIONS_KEY = OPERATIONS_KEY -# Default operations / metrics to register +# Default values DEFAULT_OPERATIONS = {"operations": [{"name": "hfcontrols", "class": "HFControls"}]} DEFAULT_METRICS = {} +DEFAULT_CONFIG = {} +DEFAULT_TRIGGER_LOG_LEVEL = "debug" # pylint: disable=too-many-instance-attributes class TrainerControllerCallback(TrainerCallback): @@ -250,14 +253,14 @@ def _take_control_actions(self, event_name: str, **kwargs): continue if rule_succeeded: for operation_action in control_action.operation_actions: - logger.info( - "Taking [%s] action in controller [%s]", - operation_action.action, - control_action.name, - ) operation_action.instance.act( action=operation_action.action, + log_level=control_action.config[ + CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL + ], event_name=event_name, + control_name=control_action.name, + **self.metrics, **kwargs, ) @@ -302,6 +305,7 @@ def on_init_end( kwargs["state"] = state kwargs["control"] = control + log_levels = logging.get_log_levels_dict() # Check if there any metrics listed in the configuration if ( CONTROLLER_METRICS_KEY not in self.trainer_controller_config @@ -398,8 +402,24 @@ def on_init_end( ), operation_actions=[], ) + config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL if CONTROLLER_CONFIG_KEY in controller: control.config = controller[CONTROLLER_CONFIG_KEY] + config_log_level_str = control.config.get( + CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL, config_log_level_str + ) + if config_log_level_str not in log_levels: + logger.warning( + "Incorrect trigger log-level [%s] specified in the config." + " Defaulting to 'debug' level", + config_log_level_str, + ) + config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL + else: + control.config = DEFAULT_CONFIG + control.config[CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL] = log_levels[ + config_log_level_str + ] if CONTROLLER_PATIENCE_CONFIG_KEY in controller: control.patience = PatienceControl( **controller[CONTROLLER_PATIENCE_CONFIG_KEY] @@ -548,3 +568,59 @@ def on_evaluate( kwargs["state"] = state kwargs["control"] = control self._actions_on_event(event_name="on_evaluate", **kwargs) + + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_save", **kwargs) + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_step_begin", **kwargs) + + def on_optimizer_step( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_optimizer_step", **kwargs) + + def on_substep_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_substep_end", **kwargs) diff --git a/tuning/trainercontroller/controllermetrics/__init__.py b/tuning/trainercontroller/controllermetrics/__init__.py index 1f9f76705..6a8165852 100644 --- a/tuning/trainercontroller/controllermetrics/__init__.py +++ b/tuning/trainercontroller/controllermetrics/__init__.py @@ -23,6 +23,7 @@ from .history_based_metrics import HistoryBasedMetric from .loss import Loss from .trainingstate import TrainingState +from tuning.trainercontroller.controllermetrics.per_process_state import PerProcessState # List of metric handlers handlers = [] @@ -39,6 +40,7 @@ def register(cl: Type): # Register the default metric handlers in this package here register(TrainingState) +register(PerProcessState) register(EvalMetrics) register(Loss) register(HistoryBasedMetric) diff --git a/tuning/trainercontroller/controllermetrics/eval_metrics.py b/tuning/trainercontroller/controllermetrics/eval_metrics.py index 696714437..a87772674 100644 --- a/tuning/trainercontroller/controllermetrics/eval_metrics.py +++ b/tuning/trainercontroller/controllermetrics/eval_metrics.py @@ -18,14 +18,9 @@ # Standard from typing import Any -# Third Party -from transformers.utils import logging - # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler -logger = logging.get_logger(__name__) - class EvalMetrics(MetricHandler): """Implements the controller metric which exposes the evaluation metrics""" diff --git a/tuning/trainercontroller/controllermetrics/history_based_metrics.py b/tuning/trainercontroller/controllermetrics/history_based_metrics.py index ae547d3c6..f66d634e5 100644 --- a/tuning/trainercontroller/controllermetrics/history_based_metrics.py +++ b/tuning/trainercontroller/controllermetrics/history_based_metrics.py @@ -21,12 +21,10 @@ # Third Party from transformers import TrainerState -from transformers.utils import logging # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler -logger = logging.get_logger(__name__) METRICS_KEY = "metrics" LOG_LOSS_KEY = "loss" TRAINING_LOSS_KEY = "training_loss" diff --git a/tuning/trainercontroller/controllermetrics/loss.py b/tuning/trainercontroller/controllermetrics/loss.py index 2fd450148..543d6395c 100644 --- a/tuning/trainercontroller/controllermetrics/loss.py +++ b/tuning/trainercontroller/controllermetrics/loss.py @@ -61,4 +61,4 @@ def compute(self, state: TrainerState = None, **kwargs) -> Any: log = state.log_history[i] if "loss" not in log: continue - return float(log["loss"]) + return log diff --git a/tuning/trainercontroller/controllermetrics/metrics.yaml b/tuning/trainercontroller/controllermetrics/metrics.yaml new file mode 100644 index 000000000..d3a8a32de --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/metrics.yaml @@ -0,0 +1,9 @@ +controller-metrics: + - name: loss + class: Loss + - name: state + class: TrainingState + - name: eval_metrics + class: EvalMetrics + - name: per_process_state + class: PerProcessState diff --git a/tuning/trainercontroller/controllermetrics/per_process_state.py b/tuning/trainercontroller/controllermetrics/per_process_state.py new file mode 100644 index 000000000..58a96de37 --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/per_process_state.py @@ -0,0 +1,77 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Any + +# Third Party +from transformers import TrainerState +import torch + +# Local +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler + + +class PerProcessState(MetricHandler): + """Implements the controller metric which exposes the per process state""" + + def __init__(self, **kwargs): + """Initializes the metric handler, by registering the event \ + list and arguments with base handler. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + super().__init__( + events=[ + "on_init_end", + "on_step_end", + "on_epoch_begin", + "on_epoch_end", + "on_prediction_step", + "on_predict", + "on_log", + "on_train_end", + "on_train_begin", + "on_evaluate", + "on_save", + ], + **kwargs, + ) + + def validate(self) -> bool: + """Validate the training arguments (e.g logging_steps) are \ + compatible with the computation of this metric. + + Returns: + bool + """ + return True + + def compute(self, _: TrainerState = None, **kwargs) -> Any: + """Exposes the trainer state. + + Args: + state: TrainerState object + kwargs: Remaining event arguments + + Returns: + dict. Trainer state as a dictionary + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return {"rank": torch.distributed.get_rank()} + return {"rank": None} diff --git a/tuning/trainercontroller/controllermetrics/trainingstate.py b/tuning/trainercontroller/controllermetrics/trainingstate.py index 59ab3638c..8dc276339 100644 --- a/tuning/trainercontroller/controllermetrics/trainingstate.py +++ b/tuning/trainercontroller/controllermetrics/trainingstate.py @@ -21,10 +21,13 @@ # Third Party from transformers import TrainerState +from transformers.utils import logging # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler +logger = logging.get_logger(__name__) + class TrainingState(MetricHandler): """Implements the controller metric which exposes the trainer state""" @@ -49,7 +52,7 @@ def __init__(self, **kwargs): "on_train_begin", "on_evaluate", ], - **kwargs + **kwargs, ) def validate(self) -> bool: diff --git a/tuning/trainercontroller/operations/__init__.py b/tuning/trainercontroller/operations/__init__.py index 99456d7ec..c0253d8f4 100644 --- a/tuning/trainercontroller/operations/__init__.py +++ b/tuning/trainercontroller/operations/__init__.py @@ -3,6 +3,7 @@ # Local from .hfcontrols import HFControls +from .logcontrol import LogControl from .operation import Operation # List of operation handlers @@ -20,3 +21,4 @@ def register(cl: Type): # Register the default operation handlers in this package here register(HFControls) +register(LogControl) diff --git a/tuning/trainercontroller/operations/hfcontrols.py b/tuning/trainercontroller/operations/hfcontrols.py index 2bba9a1d2..0548b4c12 100644 --- a/tuning/trainercontroller/operations/hfcontrols.py +++ b/tuning/trainercontroller/operations/hfcontrols.py @@ -1,17 +1,15 @@ # Standard from dataclasses import fields import inspect +import logging import re # Third Party from transformers import TrainerControl -from transformers.utils import logging # Local from .operation import Operation -logger = logging.get_logger(__name__) - class HFControls(Operation): """Implements the control actions for the HuggingFace controls in @@ -29,7 +27,7 @@ def __init__(self, **kwargs): for control_field in fields(TrainerControl): if re.search(r"^should_.+", control_field.name) is not None: setattr(self, control_field.name, self.control_action) - super().__init__() + super().__init__(**kwargs) def control_action(self, control: TrainerControl, **kwargs): """This method peeks into the stack-frame of the caller to get the action the triggered @@ -39,7 +37,7 @@ def control_action(self, control: TrainerControl, **kwargs): control: TrainerControl. Data class for controls. kwargs: List of arguments (key, value)-pairs """ - logger.debug("Arguments passed to control_action: %s", repr(kwargs)) + logging.debug("Arguments passed to control_action: %s", repr(kwargs)) frame_info = inspect.currentframe().f_back arg_values = inspect.getargvalues(frame_info) setattr(control, arg_values.locals["action"], True) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py new file mode 100644 index 000000000..385de3b4d --- /dev/null +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -0,0 +1,55 @@ +# Third Party +from transformers import TrainingArguments +from transformers.utils import logging + +# Local +from .operation import Operation + +logger = logging.get_logger(__name__) +logger.setLevel(level=logging.DEBUG) + + +class LogControl(Operation): + """Operation that can be used to log useful information on specific events.""" + + def __init__(self, log_format: str, log_level: str, **kwargs): + """Initializes the HuggingFace controls. In this init, the fields with `should_` of the + transformers.TrainerControl data class are extracted, and for each of those fields, the + control_action() method's pointer is set, and injected as a class member function. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + log_levels = logging.get_log_levels_dict() + if log_level not in log_levels: + raise ValueError( + "Specified log_level [%s] is invalid for LogControl" % (log_level) + ) + self.log_level = log_levels[log_level] + self.log_format = log_format + super().__init__(**kwargs) + + def should_log( + self, + event_name: str = None, + control_name: str = None, + args: TrainingArguments = None, + **kwargs, + ): + """This method peeks into the stack-frame of the caller to get the action the triggered + a call to it. Using the name of the action, the value of the control is set. + + Args: + control: TrainerControl. Data class for controls. + kwargs: List of arguments (key, value)-pairs + """ + log_msg = self.log_format.format( + event_name=event_name, + control_name=control_name, + args=args, + **kwargs, + ) + logger.log( + self.log_level, + log_msg, + ) diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 916420e81..70805a015 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -3,22 +3,39 @@ import inspect import re +# Third Party +from transformers.utils import logging + +logger = logging.get_logger(__name__) + class Operation(metaclass=abc.ABCMeta): """Base class for operations""" - def __init__(self): + def __init__(self, name: str, **kwargs): """Initializes the HuggingFace controls. In this init, we follow the convention that every action should preceed with prefix `should_`. If so, it is treated as a valid action. """ + self._name = name + self.kwargs = kwargs self.valid_actions = {} + self.name = name + self.kwargs = kwargs for action_name, action_method in inspect.getmembers( self, predicate=inspect.ismethod ): if re.search(r"^should_.+", action_name) is not None: self.valid_actions[action_name] = action_method + def get_name(self) -> str: + """Returns the name of the operation. + + Returns: + str + """ + return self._name + def validate(self, action: str) -> bool: """Validates the action by checking if it valid action or not. @@ -30,15 +47,34 @@ def validate(self, action: str) -> bool: """ return action in self.valid_actions - def act(self, action: str, **kwargs): + def act( + self, + action: str, + log_level: int, + event_name: str = None, + control_name: str = None, + **kwargs, + ): """Validates the action and invokes it. Args: action: str. String depicting the action. + event_name: str. Event name triggering the act. + control_name: str. Name of the controller defining the act. + log_level: int. Log level for triggering the log. kwargs: List of arguments (key, value)-pairs. """ if not self.validate(action): raise ValueError(f"Invalid operation {action}") + logger.log( + log_level, + "Taking [%s] action in controller [%s] triggered at event [%s]", + action, + control_name, + event_name, + ) + kwargs["event_name"] = event_name + kwargs["control_name"] = control_name self.valid_actions[action](**kwargs) def get_actions(self) -> list[str]: diff --git a/tuning/trainercontroller/patience.py b/tuning/trainercontroller/patience.py index b8098fdf0..ecdb0699a 100644 --- a/tuning/trainercontroller/patience.py +++ b/tuning/trainercontroller/patience.py @@ -15,8 +15,8 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ -# Third Party -from transformers.utils import logging +# Standard +import logging # Resets the patience if the rule outcome happens to be false. # Here, the expectation is to have unbroken "True"s for patience @@ -31,8 +31,6 @@ # will be exceeded afer the fifth event. MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure" -logger = logging.get_logger(__name__) - class PatienceControl: """Implements the patience control for every rule""" @@ -51,7 +49,7 @@ def should_tolerate( elif self._mode == MODE_RESET_ON_FAILURE: self._patience_counter = 0 if self._patience_counter <= self._patience_threshold: - logger.debug( + logging.debug( "Control {} triggered on event {}: " "Enforcing patience [patience_counter = {:.2f}, " "patience_threshold = {:.2f}]".format( @@ -62,7 +60,7 @@ def should_tolerate( ) ) return True - logger.debug( + logging.debug( "Control {} triggered on event {}: " "Exceeded patience [patience_counter = {:.2f}, " "patience_threshold = {:.2f}]".format( diff --git a/tuning/utils/data_type_utils.py b/tuning/utils/data_type_utils.py index cefebb100..52bae6d77 100644 --- a/tuning/utils/data_type_utils.py +++ b/tuning/utils/data_type_utils.py @@ -14,13 +14,11 @@ # Standard from typing import Union +import logging # Third Party -from transformers.utils import logging import torch -logger = logging.get_logger("data_utils") - def str_to_torch_dtype(dtype_str: str) -> torch.dtype: """Given a string representation of a Torch data type, convert it to the actual torch dtype. @@ -35,7 +33,7 @@ def str_to_torch_dtype(dtype_str: str) -> torch.dtype: """ dt = getattr(torch, dtype_str, None) if not isinstance(dt, torch.dtype): - logger.error(" ValueError: Unrecognized data type of a torch.Tensor") + logging.error(" ValueError: Unrecognized data type of a torch.Tensor") raise ValueError("Unrecognized data type of a torch.Tensor") return dt diff --git a/tuning/utils/logging.py b/tuning/utils/logging.py new file mode 100644 index 000000000..1f1b6c73e --- /dev/null +++ b/tuning/utils/logging.py @@ -0,0 +1,64 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import logging +import os + + +def set_log_level(train_args, logger_name=None): + """Set log level of python native logger and TF logger via argument from CLI or env variable. + + Args: + train_args + Training arguments for training model. + logger_name + Logger name with which the logger is instantiated. + + Returns: + train_args + Updated training arguments for training model. + train_logger + Logger with updated effective log level + """ + + # Clear any existing handlers if necessary + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + # Configure Python native logger and transformers log level + # If CLI arg is passed, assign same log level to python native logger + log_level = "WARNING" + if train_args.log_level != "passive": + log_level = train_args.log_level + + # If CLI arg not is passed and env var LOG_LEVEL is set, + # assign same log level to both logger + elif os.environ.get("LOG_LEVEL"): + log_level = os.environ.get("LOG_LEVEL") + train_args.log_level = ( + log_level.lower() + if not os.environ.get("TRANSFORMERS_VERBOSITY") + else os.environ.get("TRANSFORMERS_VERBOSITY") + ) + + logging.basicConfig( + format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper() + ) + + if logger_name: + train_logger = logging.getLogger(logger_name) + else: + train_logger = logging.getLogger() + return train_args, train_logger diff --git a/tuning/utils/preprocessing_utils.py b/tuning/utils/preprocessing_utils.py index 88db911a6..68b2755d8 100644 --- a/tuning/utils/preprocessing_utils.py +++ b/tuning/utils/preprocessing_utils.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # Standard -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union import json +import logging # Third Party -from datasets import Dataset +from datasets import Dataset, IterableDataset +from datasets.exceptions import DatasetGenerationError from transformers import AutoTokenizer, DataCollatorForSeq2Seq -from transformers.utils import logging from trl import DataCollatorForCompletionOnlyLM import datasets @@ -26,7 +27,24 @@ from tuning.config import configs from tuning.utils.data_utils import apply_custom_formatting_template -logger = logging.get_logger("sft_trainer_preprocessing") +# In future we may make the fields configurable +JSON_INPUT_KEY = "input" +JSON_OUTPUT_KEY = "output" + + +# check if the provided dataset is pretokenized or not +# the check is taken from trl +# https://github.com/huggingface/trl/blob/ddf4c8dc3ecf6d9ee2b24f94c62182ffd682c808/trl/trainer/sft_trainer.py#L498-L509 +def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]): + if not data: + return False + if isinstance(data, str): + try: + data = datasets.load_dataset("json", data_files=data, split="train[:1]") + except DatasetGenerationError as e: + raise DatasetGenerationError("failed to load the provided dataset") from e + + return ("input_ids" in data.column_names) and ("labels" in data.column_names) def validate_data_args(data_args: configs.DataArguments, packing: bool): @@ -35,22 +53,48 @@ def validate_data_args(data_args: configs.DataArguments, packing: bool): data_args.training_data_path, str ), "Training data path has to be set and str" - # Dataset containing single sequence needs a response template for masking - if data_args.response_template is None and data_args.dataset_text_field is not None: - if packing is False: + is_train_data_pretokenized = is_pretokenized_dataset(data_args.training_data_path) + is_eval_data_pretokenized = is_pretokenized_dataset(data_args.validation_data_path) + + ### Data format 1 + # if the provided train dataset is pretokenized + # however user provides formatting flags, error out + if is_train_data_pretokenized: + if ( + data_args.response_template + or data_args.data_formatter_template + or data_args.dataset_text_field + ): raise ValueError( - "Since dataset_text_field is provided and packing is disabled, \ - needs a corresponding response template for masking" + "fields response_template, data_formatter_template, and dataset_text_field \ + are not applicable for pretokenized datasets" ) - # Currently if packing is false, we require a response_template. This may change in future. - if packing is False: - if data_args.response_template is None: + # if the train dataset is pretokenized + # ensure validation dataset is pretokenized otherwise error out + if data_args.validation_data_path and not is_eval_data_pretokenized: raise ValueError( - "Response template is None, needs to be set for training \ - with packing disabled." + "validation data should be pretokenized to be used \ + along with pretokenized train data" ) + # packing wont be available for pretokenized datasets in trl library + # see: https://github.com/huggingface/trl/issues/1848 + if packing: + raise ValueError("packing will not be used when datasets are pretokenized") + return + + ### Data format 2 + # Dataset containing single sequence needs a response template for masking + if data_args.dataset_text_field or data_args.data_formatter_template: + if data_args.response_template is None: + if packing is False: + raise ValueError( + "Since dataset_text_field or data_formatter_template \ + is provided and packing is disabled, \ + needs a corresponding response template for masking" + ) + if data_args.response_template: # To use Response template, pass datasets with single sequence instances \ # or a formatter template to create single sequence on the fly. @@ -65,16 +109,31 @@ def validate_data_args(data_args: configs.DataArguments, packing: bool): "dataset_text_field and data_formatter_template are both set,\ but are mutually exclusive options" ) - # TODO(s) In future seupport two more formats: - # 1. Allow no response template, and JSON with input/output fields and mask input - # 2. Allow pretokenized Dataset besides JSON. + ### Data format 3 + # If not single sequence, JSON should contain input/output fields + if not (data_args.dataset_text_field or data_args.data_formatter_template): + json_dataset = datasets.load_dataset( + "json", data_files=data_args.training_data_path + ) + if JSON_INPUT_KEY not in json_dataset["train"].column_names: + raise ValueError( + "JSON should contain input field if no dataset_text_field or \ + data_formatter_template specified" + ) + if JSON_OUTPUT_KEY not in json_dataset["train"].column_names: + raise ValueError( + "JSON should contain output field if no dataset_text_field or \ + data_formatter_template specified" + ) def get_data_collator( packing: bool, response_template: Optional[str], tokenizer: AutoTokenizer, + formatted_train_dataset: Dataset, + max_seq_length: int, ) -> Callable: """Create and return the the appropriate collator type based on the configuration for packing, response_template, and dataset_text_field. @@ -86,11 +145,17 @@ def get_data_collator( Response template to be used for formatting by TRL. tokenizer: AutoTokenizer Loaded tokenizer object to be used by the collator. + formatted_train_dataset: Dataset + Train Dataset formatted for tuning + max_seq_length: int + Max sequence length expected Returns: Callable Callable collator to be leveraged by the trainer. """ + is_train_data_pretokenized = is_pretokenized_dataset(formatted_train_dataset) + if not packing: # TODO: near term - how response template ids are parsed out needs to be cleaned. # The [2:] here applies if response template has \n prefix, it is needed to strip \n, @@ -105,30 +170,44 @@ def get_data_collator( tokenizer=tokenizer, ignore_index=configs.IGNORE_INDEX, ) - # TO DO with future changes, - # 1. Support no packing and seq2seq colator without response template - # # if dataset_text_field is None and response_template is None: - # # Use the seq2seq data collator; - # # Note that this automatically pads labels with -100 - # return DataCollatorForSeq2Seq( - # tokenizer=tokenizer, padding=True, max_length=max_sequence_length - # ) - # 2. add anything needed for preprocessed input + # Note that this automatically pads labels with -100 + # TODO check if this is sufficient for preprocessed + if is_train_data_pretokenized: + return DataCollatorForSeq2Seq( + tokenizer=tokenizer, padding=True, max_length=max_seq_length + ) raise ValueError( "Could not pick a data collator. Please refer to supported data formats" ) -def format_dataset(data_args: configs.DataArguments, tokenizer: AutoTokenizer): +def format_dataset( + data_args: configs.DataArguments, tokenizer: AutoTokenizer, max_seq_length: int +): """ Args: data_args: tuning.config.configs.DataArguments tokenizer: AutoTokenizer + max_seq_length: int + Max sequence length expected Returns: Tuple(Dataset, Dataset, str) tuple containing train_dataset, eval_dataset and dataset_text_field """ eval_dataset = None + is_train_data_pretokenized = is_pretokenized_dataset(data_args.training_data_path) + + if is_train_data_pretokenized: + train_dataset = datasets.load_dataset( + "json", data_files=data_args.training_data_path, split="train" + ) + if data_args.validation_data_path: + eval_dataset = datasets.load_dataset( + "json", data_files=data_args.validation_data_path, split="train" + ) + # dataset_text_field is irrelevant to pretokenized datasets + return train_dataset, eval_dataset, None + dataset_text_field = data_args.dataset_text_field if data_args.data_formatter_template or dataset_text_field: if dataset_text_field is None: @@ -139,7 +218,7 @@ def format_dataset(data_args: configs.DataArguments, tokenizer: AutoTokenizer): tokenizer, data_args.data_formatter_template, ) - logger.info("Training dataset length is %s", len(train_dataset)) + logging.info("Training dataset length is %s", len(train_dataset)) if data_args.validation_data_path: (eval_dataset) = get_formatted_dataset_with_single_sequence( data_args.validation_data_path, @@ -147,133 +226,26 @@ def format_dataset(data_args: configs.DataArguments, tokenizer: AutoTokenizer): tokenizer, data_args.data_formatter_template, ) - logger.info("Validation dataset length is %s", len(eval_dataset)) - # TODO: add a else here for preprocessing - return train_dataset, eval_dataset, dataset_text_field - - -################################################################################### -### The functions below are not yet used. Iterative development towards new features - - -def get_data_collator_temp( - packing: bool, - dataset_text_field: Optional[str], - response_template: Optional[str], - max_sequence_length: int, - tokenizer: AutoTokenizer, -) -> Callable: - """Create and return the the appropriate collator type based on the configuration for packing, - response_template, and dataset_text_field. - - Args: - packing: bool - Whether or not we should apply packing or not. - dataset_text_field: Optional[str] - Dataset text field fto be used for formatting by TRL. - response_template: Optional[str] - Response template to be used for formatting by TRL. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - - Returns: - Callable - Callable collator to be leveraged by the trainer. - """ - if not packing: - if dataset_text_field is None and response_template is None: - # Use the seq2seq data collator; note that this automatically pads labels with -100 - return DataCollatorForSeq2Seq( - tokenizer=tokenizer, padding=True, max_length=max_sequence_length - ) - # TODO: near term - how response template ids are parsed out needs to be cleaned. - # The [2:] here applies if response template has \n prefix, it is needed to strip \n, - # otherwise template is not found. We will create issue to clean this out after we discuss - # data formats and collators we will support. - response_template_ids = tokenizer.encode( - response_template, add_special_tokens=False - )[2:] - return DataCollatorForCompletionOnlyLM( - response_template=response_template_ids, - tokenizer=tokenizer, - ignore_index=configs.IGNORE_INDEX, - ) - - -def get_data_trainer_kwargs( - training_data_path: str, - validation_data_path: str, - packing: bool, - response_template: Optional[str], - max_sequence_length: int, - tokenizer: AutoTokenizer, - dataset_text_field: Optional[str], -) -> Dict[str, Any]: - """Get trainer args related to data / processing. At the moment, this consists of: - - the training dataset - - the evaluation dataset - - the data collator - - Maybe a formatting a function [only for a special case for validation] - The result can be kwarg expanded into the trainer initialization. - - Args: - training_data_path: str - Path to the training data. - validation_data_path: str - Path to the validation data. - packing: bool - Whether or not we should apply packing or not. - response_template: Optional[str] - Response template to be used for formatting by TRL. - max_sequence_length: int - Max sequence length to be used for sequence tokenization. - tokenizer: AutoTokenizer - Loaded tokenizer object to be used by the collator. - dataset_text_field: Optional[str] - Dataset text field fto be used for formatting by TRL. - - Returns: - Dict[str, Any] - Data related kwargs to be used by the SFT Trainer. - """ - data_collator = get_data_collator_temp( - packing, dataset_text_field, response_template, max_sequence_length, tokenizer - ) - eval_dataset = None - data_kwargs = {} - if isinstance(data_collator, DataCollatorForSeq2Seq): - # HACK: This function is never called, but is needed to sidestep TRL's internal validation. - data_kwargs["formatting_func"] = lambda x: x + logging.info("Validation dataset length is %s", len(eval_dataset)) + else: + # This is for JSON containing input/output fields train_dataset = get_preprocessed_dataset( - training_data_path, + data_args.training_data_path, tokenizer, - max_sequence_length, - input_field_name="input", - output_field_name="output", + max_seq_length, + input_field_name=JSON_INPUT_KEY, + output_field_name=JSON_OUTPUT_KEY, ) - if validation_data_path: + if data_args.validation_data_path: eval_dataset = get_preprocessed_dataset( - validation_data_path, + data_args.validation_data_path, tokenizer, - max_sequence_length, - input_field_name="input", - output_field_name="output", - ) - else: - train_dataset = get_formatted_dataset_with_single_sequence( - training_data_path, dataset_text_field, tokenizer - ) - if validation_data_path: - eval_dataset = get_formatted_dataset_with_single_sequence( - validation_data_path, dataset_text_field, tokenizer + max_seq_length, + input_field_name=JSON_INPUT_KEY, + output_field_name=JSON_OUTPUT_KEY, ) - data_kwargs["data_collator"] = data_collator - data_kwargs["train_dataset"] = train_dataset - data_kwargs["eval_dataset"] = eval_dataset - return data_kwargs + return train_dataset, eval_dataset, dataset_text_field def get_formatted_dataset_with_single_sequence( @@ -396,7 +368,7 @@ def get_jsonl_object(): ### Utils for custom masking / manipulating input / output strs, etc -def combine_sequence(input_element: str, output_element: str): +def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): """Combines / concatenates input & output element. Args: @@ -404,6 +376,9 @@ def combine_sequence(input_element: str, output_element: str): Input component of the combined sequence. output_element: str Output component of the combined sequence. + eos_token: str + EOS token associated with the tokenizer. \ + If passed, it will be concatenated at end Returns: str @@ -412,8 +387,8 @@ def combine_sequence(input_element: str, output_element: str): if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( (" ", "\n", "\t") ): - return input_element + " " + output_element - return input_element + output_element + return input_element + " " + output_element + eos_token + return input_element + output_element + eos_token def preprocess_and_tokenize( @@ -445,7 +420,7 @@ def preprocess_and_tokenize( Dictionary containing the input IDs/labels/attention mask for this record. """ combined_seq = combine_sequence( - element[input_field_name], element[output_field_name] + element[input_field_name], element[output_field_name], tokenizer.eos_token ) tokenized_comb_seqs = tokenizer( From 1107e0042e7e4fdfa8e60f8fb33a03327b2278f6 Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Sep 2024 13:53:54 -0400 Subject: [PATCH 3/7] deps: Add protobuf to support ALLaM models (#328) undefined --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e31192470..d0d5e944f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "tqdm>=4.66.2,<5.0", "trl>=0.9.3,<1.0", "peft>=0.8.0,<0.13", +"protobuf>=5.28.0,<6.0.0", "datasets>=2.15.0,<3.0", "fire>=0.5.0,<1.0", "simpleeval>=0.9.13,<1.0", From 16543eed5f549a847f65e340dbe55c9c10ea6d4c Mon Sep 17 00:00:00 2001 From: Will Date: Tue, 3 Sep 2024 16:28:24 -0400 Subject: [PATCH 4/7] deps: set previous versions for accelerate and trl for patch release (#329) * deps: Downgrade accelerate for patch release as it caused issues Signed-off-by: Will Johnson * deps: Downgrade trl to version from fms-hf-tuning 1.2.1 Signed-off-by: Will Johnson * fix: use exact previous versions Signed-off-by: Will Johnson --------- Signed-off-by: Will Johnson --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0d5e944f..9843a8274 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,13 +27,13 @@ classifiers=[ ] dependencies = [ "numpy>=1.26.4,<2.0", -"accelerate>=0.20.3,<0.40", +"accelerate==0.33", "transformers>4.41,<5.0", "torch>=2.2.0,<3.0", "sentencepiece>=0.1.99,<0.3", "tokenizers>=0.13.3,<1.0", "tqdm>=4.66.2,<5.0", -"trl>=0.9.3,<1.0", +"trl==0.9.6", "peft>=0.8.0,<0.13", "protobuf>=5.28.0,<6.0.0", "datasets>=2.15.0,<3.0", From 9b8245e74144f7ee73b7241a1687b6c77f0eb2e4 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Tue, 1 Oct 2024 09:20:51 -0600 Subject: [PATCH 5/7] build(deps): unset hardcoded trl version to get latest updates (#358) Signed-off-by: Anh Uong --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 67623503d..2b63b8f54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "sentencepiece>=0.1.99,<0.3", "tokenizers>=0.13.3,<1.0", "tqdm>=4.66.2,<5.0", -"trl==0.9.6", +"trl>=0.9.3,<1.0", "peft>=0.8.0,<0.13", "protobuf>=5.28.0,<6.0.0", "datasets>=2.15.0,<3.0", From 8cf6795f1a343dd184e3ef8e199c1679063d70a3 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Thu, 20 Mar 2025 14:11:48 -0600 Subject: [PATCH 6/7] deps: install mamba_ssm from package instead of github Signed-off-by: Anh Uong --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fe0726494..cf3688f33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ aim = ["aim>=3.19.0,<4.0"] mlflow = ["mlflow"] fms-accel = ["fms-acceleration>=0.6"] gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"] -mamba = ["mamba_ssm[causal-conv1d] @ git+https://github.com/state-spaces/mamba.git"] +mamba = ["mamba_ssm[causal-conv1d]>=2.0.0,<3.0.0"] scanner-dev = ["HFResourceScanner>=0.1.0"] From 29cbf0766fb2d6cf6d9466660fc5cd84f7f57a51 Mon Sep 17 00:00:00 2001 From: Dushyant Behl Date: Mon, 10 Nov 2025 22:38:21 +0530 Subject: [PATCH 7/7] tag peft to a rc instead of commit tag Signed-off-by: Dushyant Behl --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 10fc72b4e..e1f108552 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "tokenizers<=0.22", "tqdm>=4.66.2,<5.0", "trl>=0.19.1,<0.20.0", -"peft @ git+https://github.com/huggingface/peft.git@293aea5df6db240856a77f89955d1a89ce38b50d", +"peft==0.18.0rc0", "datasets>=4.0.0,<5.0.0", "simpleeval>=0.9.13,<2.0", "pillow>=11.0.0,<12.0",