diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 1d072958..eadef50f 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -72,3 +72,11 @@ class train_config:
     stage2_prompt_length: int = 64
     stage2_batch_size: int = 96
     stage2_seq_length: int = 256
+
+    # FIM training
+    fim_training: bool = False
+    psm_rate: float = 0.0
+    spm_rate: float = 0.0
+    fim_pre: int = 1
+    fim_mid: int = 2
+    fim_suf: int = 3
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 4b811d6d..d4bbc984 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -5,6 +5,7 @@
     AutoHandler,
     BufferDataset,
     CheckpointDataset,
+    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
@@ -57,9 +58,9 @@ def __iter__(self):
     return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)


-def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
+def get_data_loader(cfg, rank, world_size):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
+    Pytorch dataloader for stateful, distributed, and rescalable language model training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -70,11 +71,11 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
-    postprocess : List[Callable]
-        Any task-specific postprocessing to apply before handing over data. Steps will apply in
-        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """

+    if cfg.fim_training:
+        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
+
     datasets, weights = parse_data_args(cfg.datasets, cfg.weights)

     # Base streaming dataset. Returns doc chunks in sequence.
@@ -118,9 +119,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
+    # Increment seq len to counteract CLM's one token removal.
     data = BufferDataset(
         data,
-        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
+        cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
@@ -128,10 +130,23 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 10000)

-    # Apply desired postprocessing steps in sequence
+    # Apply FIM transformation if needed
+    if cfg.fim_training:
+        data = FIMDataset(
+            data,
+            cfg.eos_token,
+            cfg.psm_rate,
+            cfg.spm_rate,
+            pre_token=cfg.fim_pre,
+            mid_token=cfg.fim_mid,
+            suf_token=cfg.fim_suf,
+        )
+
+    # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
-    for p in postprocess:
-        data = PreprocessDataset(data, p)
+
+    # Apply CLM transformation
+    data = PreprocessDataset(data, causal_lm)

     # Enable auto-saving
     data = CheckpointDataset(
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index aedc5862..7d634161 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -696,6 +696,128 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         return sharded_dicts


+class FIMDataset(_WrapperDataset):
+    """
+    Wrapper for a _StatefulDataset that implements Fill-In-the-Middle training
+    (https://arxiv.org/pdf/2207.14255).
+    Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
+    Breaks sequence apart into component document spans, and for each document span
+    of sufficient length, transforms with specified probability into:
+    PSM mode:
+        (pre_token) (prefix) (suf_token) (suffix) (mid_token) (middle)
+    SPM mode:
+        (pre_token) (suf_token) (suffix) (mid_token) (prefix) (middle)
+    The new delimiter tokens can be omitted by passing in None.
+    Any extra tokens after transformation are dropped from the end of the sequence.
+    ...
+    Args
+    ----
+    dataset : _StatefulDataset
+        Fully instantiated dataset
+    delimiter_token : any
+        Token used to indicate document boundaries
+    psm_rate : float
+        Chance to transform into PSM. Cannot exceed 1.
+    spm_rate : float
+        Chance to transform into SPM. Cannot exceed 1.
+    min_len : int
+        Minimum document length to perform FIM transformation
+    pre_token : any | None
+        Token used to indicate prefix section of the document
+    mid_token : any | None
+        Token used to indicate middle infill section of the document
+    suf_token : any | None
+        Token used to indicate suffix section of the document
+    """
+
+    def __init__(
+        self,
+        dataset: _StatefulDataset,
+        delimiter_token: Any,
+        psm_rate: float = 0.0,
+        spm_rate: float = 0.0,
+        min_len: int = 10,
+        pre_token=None,
+        mid_token=None,
+        suf_token=None,
+    ):
+        super().__init__(dataset)
+        assert (
+            psm_rate + spm_rate > 0
+        ), "FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
+        assert (
+            psm_rate + spm_rate <= 1
+        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
+        self.psm = psm_rate
+        self.spm = spm_rate
+        self.delimiter = delimiter_token
+        self.min_len = min_len
+        self.pref = pre_token
+        self.suff = suf_token
+        self.midd = mid_token
+
+        self.g_state = None
+        self.generator = torch.Generator().manual_seed(self.rank)
+        self.state_params = ["g_state"]
+
+    def __iter__(self):
+        dataset = iter(self.dataset)
+        while True:
+            inp = next(dataset)
+            len_ = len(inp)
+            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
+            docs = [
+                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
+            ]  # list[list[any]]
+            out = []
+            for i in range(len(docs)):
+                doc = docs[i]
+                if len(docs[i]) >= self.min_len:
+                    # decide psm, spm, or nothing
+                    thresh = torch.rand([1], generator=self.generator).item()
+                    if thresh < self.psm + self.spm:
+                        # Split doc
+                        doc = []
+                        if self.pref:
+                            doc = [self.pref]
+                        splits = torch.randint(
+                            0, len(docs[i]), [2], generator=self.generator
+                        ).tolist()
+                        pre = docs[i][: min(splits)]
+                        mid = docs[i][min(splits) : max(splits)]
+                        suf = docs[i][max(splits) :]
+
+                        if thresh < self.psm:
+                            # PSM transformation
+                            doc += pre
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += mid
+                        else:
+                            # SPM transformation
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += pre + mid
+                out += doc + [self.delimiter]
+            yield out[:len_]
+
+    def state_dict(self):
+        # Write generator state manually
+        self.g_state = self.generator.get_state()
+        return super().state_dict()
+
+    def load_state_dict(self, state_dicts, sharded_input=False):
+        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
+        # Manually set generator state if it exists
+        if self.g_state is not None:
+            self.generator.set_state(self.g_state)
+        return sharded_dicts
+
+
 class BufferDataset(_WrapperDataset):
     """
     Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -872,9 +994,8 @@ def __init__(
         self.bos = bos_token
         self.drop = strip_tokens
         self.verbose = verbose
-        self.docset: List[
-            Any
-        ] = []  # map of doc indices to (shardid, min docid, max docid)
+        # Map of doc indices to (shardid, min docid, max docid)
+        self.docset: List[Any] = []

         # Position
         self.docset_index = 0
diff --git a/main_training_llama.py b/main_training_llama.py
index 67cccee2..a7e1020f 100644
--- a/main_training_llama.py
+++ b/main_training_llama.py
@@ -122,9 +122,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:
diff --git a/main_training_mamba.py b/main_training_mamba.py
index 3619ea25..68a3c830 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -119,9 +119,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 83b2426b..40bef481 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -632,6 +632,10 @@ def test_multi_reload_stress():
     # preload / sample / scale / doc pipeline
     multi_reload_stress_check(lambda: d6(d5(d4())))

+    # Add FIM dataset
+    d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x]
+    multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
+

# SCALABLEDATASET TESTS
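
Reviewer note (not part of the patch): the sketch below illustrates the PSM layout that FIMDataset emits for a single document. The token IDs, constant names, and the psm_transform helper are assumptions for illustration only; they are not repo code.

    import torch

    # Sentinel IDs chosen arbitrarily; they mirror fim_pre/fim_mid/fim_suf above.
    EOS, FIM_PRE, FIM_MID, FIM_SUF = 0, 1, 2, 3


    def psm_transform(doc, generator):
        # Pick two split points, giving prefix / middle / suffix spans.
        splits = torch.randint(0, len(doc), [2], generator=generator).tolist()
        lo, hi = min(splits), max(splits)
        pre, mid, suf = doc[:lo], doc[lo:hi], doc[hi:]
        # PSM order: (pre_token) prefix (suf_token) suffix (mid_token) middle
        return [FIM_PRE] + pre + [FIM_SUF] + suf + [FIM_MID] + mid


    g = torch.Generator().manual_seed(0)
    doc = list(range(10, 22))  # one 12-token "document"
    print(psm_transform(doc, g) + [EOS])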