diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py
index 337b49de..af75b0f1 100644
--- a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py
+++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py
@@ -25,11 +25,11 @@
 import math
 
 # Third Party
-import torch
 from datasets import Dataset, DatasetDict
 from sentence_transformers import SentenceTransformer
 from sklearn.cluster import KMeans
 import numpy as np
+import torch
 
 logger = getLogger(__name__)
 
diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
index c7447a55..c4d12177 100644
--- a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
+++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
@@ -36,6 +36,7 @@ def __init__(
         output_dir="odm",
         reward_type=Reward.ENTROPY,
         auto_categorize_config: Optional[dict | AutoCategorizeConfig] = None,
+        seed: Optional[int] = 42,
     ):
         """Mixes datasets with sampling ratios learnt using Multi Armed Bandit (MAB) EXP3
         and rewards defined.
@@ -69,6 +70,8 @@ def __init__(
                 configuration overrides for the auto-categorizer such as text column,
                 embedding model, cluster count etc. This will only be used if the
                 `dataset_dict` has only one key.
+            seed (Optional[int], optional): Base seed for the dataset-level RNG so all
+                distributed ranks iterate over the exact same sample order. Defaults to 42.
         """
         self.auto_categorize = len(dataset_dict.keys()) == 1
         self._auto_categorize_config = self._build_auto_categorize_config(
@@ -190,6 +193,12 @@ def __init__(
             "action": "",  # one of sample or update
         }
 
+        # Local RNG so every process can deterministically sample identical streams.
+        self.seed = 42 if seed is None else seed
+        self._rng = random.Random(self.seed)
+        self._current_epoch = 0
+        self._rng_state_restored = False
+
     def log_to_file(self, data: dict):
         """helper function to log the state to the file
 
@@ -203,9 +212,17 @@ def log_to_file(self, data: dict):
     def __iter__(self):
         return self
 
+    def set_epoch(self, epoch: int):
+        """Ensures every process observes the same RNG state per epoch."""
+        self._current_epoch = epoch
+        if self._rng_state_restored:
+            self._rng_state_restored = False
+            return
+        self._rng.seed(self.seed + epoch)
+
     def __next__(self):
         if self.produced % self.sampling_interval == 0:
-            self.arm_idx = random.choices(
+            self.arm_idx = self._rng.choices(
                 range(self.total_categories), weights=self.sampling_ratio, k=1
             )[0]
         sample = None
@@ -243,7 +260,7 @@ def __next__(self):
                 else torch.ones_like(sample["input_ids"][0])
             ),
             "labels": (
-                sample["labels"][0]
+                sample["labels"][0].tolist()
                 if "labels" in sample
                 else sample["input_ids"][0]
             ),
@@ -264,6 +281,16 @@ def load_state_dict(self, state_dict):
         torch.set_rng_state(state_dict["rng"])
         train_dataset_dict_dl_sd = state_dict.pop("train_dataset_dict_dl_sd")
         random.setstate(state_dict.pop("random_state"))
+        dataset_rng_state = state_dict.pop("online_mixing_rng_state", None)
+        saved_seed = state_dict.pop("seed", None)
+        saved_epoch = state_dict.pop("_current_epoch", None)
+        if saved_seed is not None:
+            self.seed = saved_seed
+        if saved_epoch is not None:
+            self._current_epoch = saved_epoch
+        if dataset_rng_state is not None:
+            self._rng.setstate(dataset_rng_state)
+            self._rng_state_restored = True
         for k, v in state_dict.items():
             if hasattr(self, k):
                 setattr(self, k, v)
@@ -295,6 +322,9 @@ def state_dict(self):
             "arm_idx": self.arm_idx,
             "reward_type": str(self.reward_type),
             "random_state": random.getstate(),
+            "online_mixing_rng_state": self._rng.getstate(),
+            "seed": self.seed,
+            "_current_epoch": self._current_epoch,
         }
 
     def _reset_eval_dataloaders(self):
@@ -516,8 +546,9 @@ def update_sampling_weights(self, model, accelerator, state):
         if accelerator:
             rewards = accelerator.reduce(rewards, reduction="sum")
             count = accelerator.reduce(count, reduction="sum")
+
+        self._update_weights(count, rewards)
         if accelerator and accelerator.is_main_process:
-            self._update_weights(count, rewards)
             self.log_to_file(
                 {
                     "current_sampling_weights": self.sampling_weights.tolist(),
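
Reviewer note (illustrative only, not part of the patch): the snippet below re-creates the determinism argument behind the new seed / set_epoch / RNG-state plumbing with plain random.Random objects; every name in it is local to the snippet, not an identifier from dataset.py. Seeding each rank's RNG with seed + epoch keeps the sampled arm sequence identical across ranks, and getstate()/setstate() round-trips the stream exactly, which is what state_dict()/load_state_dict() now persist under the "online_mixing_rng_state" key.

# Standalone sketch: mirrors the patch's per-epoch reseeding and RNG-state restore.
import random

SEED = 42  # analogous to the new `seed` argument's default

def arm_stream(rng, weights, n):
    # Draw n arm indices the way __next__ now does, via an instance-local RNG.
    return [rng.choices(range(len(weights)), weights=weights, k=1)[0] for _ in range(n)]

weights = [0.5, 0.3, 0.2]

# Two "ranks", each with its own RNG built from the same base seed.
rank0 = random.Random(SEED)
rank1 = random.Random(SEED)

# set_epoch(1) reseeds with seed + epoch, so both ranks stay aligned every epoch.
rank0.seed(SEED + 1)
rank1.seed(SEED + 1)
assert arm_stream(rank0, weights, 8) == arm_stream(rank1, weights, 8)

# Checkpoint/resume: capturing and restoring the RNG state reproduces the
# remaining draws of the interrupted epoch exactly.
saved = rank0.getstate()
expected = arm_stream(rank0, weights, 4)
resumed = random.Random()
resumed.setstate(saved)
assert arm_stream(resumed, weights, 4) == expected

This is also why load_state_dict() sets _rng_state_restored: the first set_epoch() call after a resume skips the seed + epoch reseed so it does not discard the restored mid-epoch state.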