
Commit 5eeec0e

feat: ODM without categories integration with fms-accel (#641)
ODM without categories and unit tests

Signed-off-by: romit <romit@ibm.com>
1 parent 717decc commit 5eeec0e

File tree

5 files changed: +114 −1 lines changed


tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@
 DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_AND_SPLIT_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling_and_split.yaml"
 )
+DATA_CONFIG_SINGLE_DATASET_ODM_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "single_dataset_with_odm.yaml"
+)
 DATA_CONFIG_MULTIPLE_DATASETS_ODM_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_odm.yaml"
 )
tests/artifacts/predefined_data_configs/single_dataset_with_odm.yaml

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+dataprocessor:
+  type: odm
+  sampling_stopping_strategy: first_exhausted # ignored
+  seed: 66
+  odm:
+    update_interval: 1 # update every step
+    sampling_interval: 1 # sample category for every sample
+    reward_type: validation_loss # uses eval loss of each dataset as reward
+    gamma: 0.1 # MAB hyper-parameter
+    eta: 0.2 # MAB hyper-parameter
+    auto_categorize_input_column: "input" # Required: input field on which clustering is applied to form pseudo categories
+    auto_categorize_num_categories: "3" # Optional: number of categories for clustering;
+    # if not provided, this will be inferred based on dataset size
+datasets:
+  - name: dataset_1
+    split:
+      train: 0.8
+      validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss.
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_column_name: input
+            output_column_name: output

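The gamma and eta values above are hyper-parameters of the multi-armed-bandit (MAB) sampler that ODM runs over dataset categories: each category is an arm, and the per-category reward (derived here from validation loss) drives the updates. The commit does not show the plugin's actual update rule; purely as a hedged illustration, a textbook EXP3-style sampler would use gamma for uniform exploration and eta as the exponential-weights learning rate. Exp3Sampler below is a name invented for this sketch, not a plugin API.

# Hedged sketch only: a textbook EXP3 bandit over dataset categories.
# The fms-accel ODM plugin's real update rule may differ.
import math
import random


class Exp3Sampler:
    def __init__(self, num_arms, gamma=0.1, eta=0.2, seed=66):
        self.num_arms = num_arms
        self.gamma = gamma  # uniform-exploration mixing weight
        self.eta = eta      # exponential-weights learning rate
        self.weights = [1.0] * num_arms
        self.rng = random.Random(seed)

    def probabilities(self):
        total = sum(self.weights)
        # Mix the normalized weights with uniform exploration (gamma).
        return [(1 - self.gamma) * w / total + self.gamma / self.num_arms
                for w in self.weights]

    def sample(self):
        # Pick the category to draw the next training example from.
        return self.rng.choices(range(self.num_arms), weights=self.probabilities())[0]

    def update(self, arm, reward):
        # Importance-weighted reward estimate, then multiplicative update.
        p = self.probabilities()[arm]
        self.weights[arm] *= math.exp(self.eta * (reward / p) / self.num_arms)

With update_interval: 1 and sampling_interval: 1 as configured above, sample() and update() would effectively run on every training step.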
tests/test_sft_trainer.py

Lines changed: 69 additions & 0 deletions
@@ -51,6 +51,7 @@
     DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML,
     DATA_CONFIG_PRETOKENIZE_DATA_YAML,
     DATA_CONFIG_RENAME_SELECT_COLUMNS,
+    DATA_CONFIG_SINGLE_DATASET_ODM_YAML,
     DATA_CONFIG_SKIP_LARGE_COLUMNS_HANDLER,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
     DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER,
@@ -2575,3 +2576,71 @@ def test_online_data_mixing_plugin_sample_training_no_validation_split(
         "What length of trench,\n25 m broad and 15 m deep can be dug in 30 days ?"
         in output_inference
     ), f"{output_inference} does not include the prompt"
+
+
+@pytest.mark.skipif(
+    not is_fms_accelerate_available(plugins="odm"),
+    reason="Only runs if fms-accelerate is installed along with online-data-mixing plugin",
+)
+@pytest.mark.parametrize(
+    "datafile, datasetconfigname, reward_type",
+    [
+        (
+            NESTFUL_DATA_INPUT_OUTPUT_JSONL,
+            DATA_CONFIG_SINGLE_DATASET_ODM_YAML,
+            "entropy",
+        ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+            DATA_CONFIG_SINGLE_DATASET_ODM_YAML,
+            "entropy",
+        ),
+    ],
+)
+def test_online_data_mixing_plugin_with_auto_categorization(
+    datafile, datasetconfigname, reward_type
+):
+    """Ensure fms_acceleration_odm plugin trains with auto-categorization"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        data_formatting_args = copy.deepcopy(DATA_ARGS)
+
+        # set training_data_path and response_template to none
+        data_formatting_args.response_template = None
+        data_formatting_args.training_data_path = None
+
+        # add data_paths in data_config file
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, suffix=".yaml"
+        ) as temp_yaml_file:
+            with open(datasetconfigname, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+                data["dataprocessor"]["odm"]["reward_type"] = reward_type
+
+                for d, df in zip(data["datasets"], [datafile]):
+                    d["data_paths"] = [df]
+
+                yaml.dump(data, temp_yaml_file)
+        data_formatting_args.data_config_path = temp_yaml_file.name
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.logging_strategy = "steps"
+        train_args.max_steps = 10
+        train_args.eval_strategy = "steps"
+        train_args.eval_steps = 1
+
+        sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)
+
+        # validate full ft configs
+        _validate_training(tempdir)
+        _, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+        )
+        assert len(output_inference) > 0
+        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference

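Both parametrized cases above override reward_type to "entropy" before training. The commit does not show how that reward is computed inside the plugin; purely as a labeled assumption, a mean predictive-entropy reward over a batch of logits could look like the sketch below (entropy_reward is a hypothetical helper, not a plugin API).

# Assumption-labeled sketch: mean per-token predictive entropy as a reward.
# This is NOT the plugin's confirmed implementation.
import torch
import torch.nn.functional as F


def entropy_reward(logits: torch.Tensor, attention_mask: torch.Tensor) -> float:
    # logits: (batch, seq_len, vocab); attention_mask: (batch, seq_len)
    log_probs = F.log_softmax(logits, dim=-1)
    # Shannon entropy of the predictive distribution at each position.
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    mask = attention_mask.to(token_entropy.dtype)
    # Average over non-padding tokens only.
    return ((token_entropy * mask).sum() / mask.sum()).item()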
tuning/config/acceleration_configs/odm.py

Lines changed: 5 additions & 1 deletion
@@ -14,7 +14,7 @@

 # Standard
 from dataclasses import dataclass
-from typing import Union
+from typing import Optional, Union

 # Local
 from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
@@ -29,6 +29,10 @@ class ODM:
     gamma: float = 0.1
     eta: float = 0.1
     resume_from_checkpoint: Union[bool, str] = False
+    auto_categorize_input_column: str = None
+    auto_categorize_num_categories: Optional[int] = None
+    auto_categorize_model_name: str = "Qwen/Qwen3-Embedding-0.6B"
+    auto_categorize_batch_size: int = 64


 @dataclass

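Taken together, the four new fields describe an auto-categorization step: embed the configured input column with Qwen/Qwen3-Embedding-0.6B in batches of auto_categorize_batch_size, then cluster the embeddings into auto_categorize_num_categories pseudo categories. The sketch below illustrates that idea; the library choices (sentence-transformers, scikit-learn) and the auto_categorize function name are assumptions for illustration, not confirmed dependencies or APIs of the plugin.

# Illustrative sketch only; libraries and function name are assumptions.
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans


def auto_categorize(texts, model_name="Qwen/Qwen3-Embedding-0.6B",
                    num_categories=3, batch_size=64, seed=66):
    # Embed each input row, then k-means the embeddings into pseudo categories.
    model = SentenceTransformer(model_name)
    embeddings = model.encode(texts, batch_size=batch_size)
    labels = KMeans(n_clusters=num_categories, random_state=seed).fit_predict(embeddings)
    return labels  # one pseudo-category id per input row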
tuning/data/setup_dataprocessor.py

Lines changed: 9 additions & 0 deletions
@@ -548,6 +548,14 @@ def setup_train_dataset_for_odm(
         processor=processor,
     )

+    auto_categorize_config = {}
+    if hasattr(odm_config.odm, "auto_categorize_input_column"):
+        auto_categorize_config = {
+            "input_column": "input_ids",
+            "num_categories": int(odm_config.odm.auto_categorize_num_categories),
+            "tokenizer": tokenizer,
+        }
+
     train_dataset = OnlineMixingDataset(
         train_dataset,
         collators,
@@ -560,6 +568,7 @@ def setup_train_dataset_for_odm(
         sampling_interval=odm_config.odm.sampling_interval,
         eval_batch_size=train_args.per_device_eval_batch_size,
         reward_type=odm_config.odm.reward_type,
+        auto_categorize_config=auto_categorize_config,
     )
     train_args.accelerator_config = {"dispatch_batches": False}
     return (True, train_dataset, data_collator)

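One caveat on the new guard: hasattr() is always True for a field declared on the ODM dataclass, so auto_categorize_config is populated even when the user never set auto_categorize_input_column, and int(...) raises TypeError when auto_categorize_num_categories is left as None. A defensive variant, offered only as a sketch and not part of this commit:

# Sketch: enable auto-categorization only when the column is actually set,
# and pass num_categories through as None so it can be inferred downstream.
auto_categorize_config = {}
if getattr(odm_config.odm, "auto_categorize_input_column", None) is not None:
    num_cats = odm_config.odm.auto_categorize_num_categories
    auto_categorize_config = {
        "input_column": "input_ids",
        "num_categories": int(num_cats) if num_cats is not None else None,
        "tokenizer": tokenizer,
    }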