From 68aadbe4ea42305a36434d94693551b7ce0b0c3f Mon Sep 17 00:00:00 2001
From: Davis Wertheimer
Date: Fri, 10 Jan 2025 16:23:48 -0500
Subject: [PATCH 1/5] Add fim

Signed-off-by: Davis Wertheimer
---
 fms_fsdp/config/training.py        |   8 ++
 fms_fsdp/utils/dataloader_utils.py |  34 ++++++---
 fms_fsdp/utils/dataset_utils.py    | 118 +++++++++++++++++++++++++++++
 tests/test_datasets.py             |   4 +
 4 files changed, 155 insertions(+), 9 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 1d072958..eadef50f 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -72,3 +72,11 @@ class train_config:
     stage2_prompt_length: int = 64
     stage2_batch_size: int = 96
     stage2_seq_length: int = 256
+
+    # FIM training
+    fim_training: bool = False
+    psm_rate: float = 0.0
+    spm_rate: float = 0.0
+    fim_pre: int = 1
+    fim_mid: int = 2
+    fim_suf: int = 3
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 4b811d6d..a3f0703a 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -5,6 +5,7 @@
     AutoHandler,
     BufferDataset,
     CheckpointDataset,
+    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
@@ -57,9 +58,9 @@ def __iter__(self):
     return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)


-def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
+def get_data_loader(cfg, rank, world_size):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
+    Pytorch dataloader for stateful, distributed, and rescalable language model training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -70,11 +71,11 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
-    postprocess : List[Callable]
-        Any task-specific postprocessing to apply before handing over data. Steps will apply in
-        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """

+    if cfg.fim_training:
+        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
+
     datasets, weights = parse_data_args(cfg.datasets, cfg.weights)

     # Base streaming dataset. Returns doc chunks in sequence.
@@ -118,9 +119,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
+    # CLM removes 1 token, FIM adds at least 3.
     data = BufferDataset(
         data,
-        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
+        cfg.seq_length - 3 if cfg.fim_training else cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
     )
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 10000)

-    # Apply desired postprocessing steps in sequence
+    # Apply FIM transformation if needed
+    if cfg.fim_training:
+        data = FIMDataset(
+            data,
+            cfg.eos_token,
+            cfg.psm_rate,
+            cfg.spm_rate,
+            pre_token=cfg.fim_pre,
+            mid_token=cfg.fim_mid,
+            suf_token=cfg.fim_suf,
+        )
+
+    # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
-    for p in postprocess:
-        data = PreprocessDataset(data, p)
+
+    # Apply CLM transformation if needed
+    if not cfg.fim_training:
+        data = PreprocessDataset(data, causal_lm)

     # Enable auto-saving
     data = CheckpointDataset(
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index aedc5862..4a72e6c2 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -694,6 +694,124 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         # Manually set buffer size
         self.buffer_size = len(self.buffer)
         return sharded_dicts
+
+
+class FIMDataset(_WrapperDataset):
+    """
+    Wrapper for a StatefulDataset that implements Fill-In-the-Middle training
+    (https://arxiv.org/pdf/2207.14255).
+    Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
+    Breaks sequence apart into component document spans, and for each document span
+    of sufficient length, transforms with specified probability into:
+    PSM mode: [pre] (prefix) [suf] (suffix) [mid] (middle) [delim]
+    SPM mode: [pre] [suf] (suffix) [mid] (prefix) (middle) [delim]
+    The new sentinel tokens (pre/mid/suf) can be omitted by passing in None.
+    Any extra tokens after transformation are dropped from the end of the sequence.
+    ...
+    Args
+    ----
+    dataset : _StatefulDataset
+        Fully instantiated dataset
+    delimiter_token : any
+        Token used to indicate document boundaries
+    psm_rate : float
+        Chance to transform into PSM. Cannot exceed 1.
+    spm_rate : float
+        Chance to transform into SPM. Cannot exceed 1.
+    min_len : int
+        Minimum document length to perform FIM transformation
+    pre_token : any | none
+        Token used to indicate prefix section of the document
+    mid_token : any | none
+        Token used to indicate middle infill section of the document
+    suf_token : any | none
+        Token used to indicate suffix section of the document
+    """
+
+    def __init__(
+        self, 
+        dataset: _StatefulDataset, 
+        delimiter_token: Any,
+        psm_rate: float = 0.0,
+        spm_rate: float = 0.0,
+        min_len: int = 10,
+        pre_token = None,
+        mid_token = None,
+        suf_token = None,
+    ):
+        super().__init__(dataset)
+        assert (
+            psm_rate + spm_rate > 0
+        ), "FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
+        assert (
+            psm_rate + spm_rate <= 1
+        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
+        self.psm = psm_rate
+        self.spm = spm_rate
+        self.delimiter = delimiter_token
+        self.min_len = min_len
+        self.pref = pre_token
+        self.suff = suf_token
+        self.midd = mid_token
+
+        self.g_state = None
+        self.generator = torch.Generator().manual_seed(self.rank)
+        self.state_params = ["g_state"]
+        
+    def __iter__(self):
+        dataset = iter(self.dataset)
+        while True:
+            inp = next(dataset)
+            len_ = len(inp)
+            i_eos = [0] + [i for i,x in enumerate(inp) if x==self.delimiter] + [len_]
+            docs = [inp[i_eos[j]+1:i_eos[j+1]] for j in range(len(i_eos)-1)]  # list[list[any]]
+            out = []
+            for i in range(len(docs)):
+                doc = docs[i]
+                if len(docs[i]) >= self.min_len:
+                    # decide psm, spm, or nothing
+                    thresh = torch.rand([1], generator=self.generator).item()
+                    if thresh < self.psm + self.spm:
+                        # Split doc
+                        doc = []
+                        if self.pref:
+                            doc = [self.pref]
+                        splits = torch.randint(0, len(docs[i]), [2], generator=self.generator).tolist()
+                        pre = docs[i][:min(splits)]
+                        mid = docs[i][min(splits):max(splits)]
+                        suf = docs[i][max(splits):]
+
+                        if thresh < self.psm:
+                            # PSM transformation
+                            doc += pre
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += mid
+                        else:
+                            # SPM transformation
+                            if self.suff:
+                                doc.append(self.suff)
+                            doc += suf
+                            if self.midd:
+                                doc.append(self.midd)
+                            doc += pre + mid
+                out += doc + [self.delimiter]
+            yield out[:len_]
+
+    def state_dict(self):
+        # Write generator state manually
+        self.g_state = self.generator.get_state()
+        return super().state_dict()
+
+    def load_state_dict(self, state_dicts, sharded_input=False):
+        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
+        # Manually set generator state if it exists
+        if self.g_state is not None:
+            self.generator.set_state(self.g_state)
+        return sharded_dicts
 
 
 class BufferDataset(_WrapperDataset):
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 83b2426b..e78febad 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -632,6 +632,10 @@ def test_multi_reload_stress():
     # preload / sample / scale / doc pipeline
     multi_reload_stress_check(lambda: d6(d5(d4())))
 
+    # Add FIM dataset
+    d7 = lambda x: [FIMDataset(d, -1, .25, .25, 10, -2, -3, -4) for d in x]
+    multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
+
 
 # SCALABLEDATASET TESTS
 

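[Aside, not part of the patch series] Below is a minimal sketch of the per-document transform that patch 1's FIMDataset applies, pulled out as a standalone function (fim_transform is a name used here only for illustration) so the PSM and SPM layouts are easy to read. The token ids are made up, and the sentinel ids mirror the fim_pre/fim_mid/fim_suf defaults added to train_config; the real class additionally walks a packed line, skips documents shorter than min_len, and truncates the transformed line back to its incoming length.

import torch


def fim_transform(doc, mode, pre=1, mid=2, suf=3, generator=None):
    """Split doc (a list of token ids) at two random cut points and reorder
    the pieces in PSM or SPM layout, mirroring FIMDataset.__iter__."""
    splits = torch.randint(0, len(doc), [2], generator=generator).tolist()
    lo, hi = min(splits), max(splits)
    prefix, middle, suffix = doc[:lo], doc[lo:hi], doc[hi:]
    if mode == "psm":
        # PSM layout: [pre] prefix [suf] suffix [mid] middle
        return [pre] + prefix + [suf] + suffix + [mid] + middle
    # SPM layout: [pre] [suf] suffix [mid] prefix middle
    return [pre, suf] + suffix + [mid] + prefix + middle


g = torch.Generator().manual_seed(0)
doc = list(range(100, 112))  # a toy "document" of 12 token ids
print(fim_transform(doc, "psm", generator=g))
print(fim_transform(doc, "spm", generator=g))
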
From ac025e21bbd5e53ca4c070ae51a1e4362f1732b6 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:24:12 -0500
Subject: [PATCH 2/5] Blacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py     |  4 +++-
 fms_fsdp/utils/dataset_utils.py | 38 ++++++++++++++++++---------------
 main_training_llama.py          |  8 ++++---
 main_training_mamba.py          |  8 ++++---
 tests/test_datasets.py          |  2 +-
 5 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index eadef50f..f985ed0d 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,7 +15,9 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    datasets: str = (
+        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
+    )
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 4a72e6c2..e6480dab 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -694,14 +694,14 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         # Manually set buffer size
         self.buffer_size = len(self.buffer)
         return sharded_dicts
-    
+
 
 class FIMDataset(_WrapperDataset):
     """
     Wrapper for a StatefulDataset that implements Fill-In-the-Middle training
     (https://arxiv.org/pdf/2207.14255).
     Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
-    Breaks sequence apart into component document spans, and for each document span 
+    Breaks sequence apart into component document spans, and for each document span
     of sufficient length, transforms with specified probability into:
     PSM mode: [pre] (prefix) [suf] (suffix) [mid] (middle) [delim]
     SPM mode: [pre] [suf] (suffix) [mid] (prefix) (middle) [delim]
@@ -729,15 +729,15 @@ class FIMDataset(_WrapperDataset):
     """
 
     def __init__(
-        self, 
-        dataset: _StatefulDataset, 
+        self,
+        dataset: _StatefulDataset,
         delimiter_token: Any,
         psm_rate: float = 0.0,
         spm_rate: float = 0.0,
         min_len: int = 10,
-        pre_token = None,
-        mid_token = None,
-        suf_token = None,
+        pre_token=None,
+        mid_token=None,
+        suf_token=None,
     ):
         super().__init__(dataset)
         assert (
@@ -757,14 +757,16 @@ def __init__(
         self.g_state = None
         self.generator = torch.Generator().manual_seed(self.rank)
         self.state_params = ["g_state"]
-        
+
     def __iter__(self):
         dataset = iter(self.dataset)
         while True:
             inp = next(dataset)
             len_ = len(inp)
-            i_eos = [0] + [i for i,x in enumerate(inp) if x==self.delimiter] + [len_]
-            docs = [inp[i_eos[j]+1:i_eos[j+1]] for j in range(len(i_eos)-1)]  # list[list[any]]
+            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
+            docs = [
+                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
+            ]  # list[list[any]]
             out = []
             for i in range(len(docs)):
                 doc = docs[i]
@@ -776,10 +778,12 @@ def __iter__(self):
                         doc = []
                         if self.pref:
                             doc = [self.pref]
-                        splits = torch.randint(0, len(docs[i]), [2], generator=self.generator).tolist()
-                        pre = docs[i][:min(splits)]
-                        mid = docs[i][min(splits):max(splits)]
-                        suf = docs[i][max(splits):]
+                        splits = torch.randint(
+                            0, len(docs[i]), [2], generator=self.generator
+                        ).tolist()
+                        pre = docs[i][: min(splits)]
+                        mid = docs[i][min(splits) : max(splits)]
+                        suf = docs[i][max(splits) :]
 
                         if thresh < self.psm:
                             # PSM transformation
@@ -990,9 +994,9 @@ def __init__(
         self.bos = bos_token
         self.drop = strip_tokens
         self.verbose = verbose
-        self.docset: List[
-            Any
-        ] = []  # map of doc indices to (shardid, min docid, max docid)
+        self.docset: List[Any] = (
+            []
+        )  # map of doc indices to (shardid, min docid, max docid)
 
         # Position
         self.docset_index = 0
diff --git a/main_training_llama.py b/main_training_llama.py
index 67cccee2..a7e1020f 100644
--- a/main_training_llama.py
+++ b/main_training_llama.py
@@ -122,9 +122,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:
diff --git a/main_training_mamba.py b/main_training_mamba.py
index 3619ea25..68a3c830 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -119,9 +119,11 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
-        if not os.path.isfile(cfg.ckpt_load_path)
-        else cfg.ckpt_load_path,
+        path=(
+            os.path.join(cfg.ckpt_load_path, "checkpoints/")
+            if not os.path.isfile(cfg.ckpt_load_path)
+            else cfg.ckpt_load_path
+        ),
         strict=False,
     )
     if not is_resuming:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e78febad..40bef481 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -633,7 +633,7 @@ def test_multi_reload_stress():
     multi_reload_stress_check(lambda: d6(d5(d4())))
 
     # Add FIM dataset
-    d7 = lambda x: [FIMDataset(d, -1, .25, .25, 10, -2, -3, -4) for d in x]
+    d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x]
     multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
 
 

From 93230dd27953b960bcc6af5ec8684ddc36fba0f1 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:27:14 -0500
Subject: [PATCH 3/5] Corrected fim/clm combo

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataloader_utils.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index a3f0703a..d4bbc984 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -119,10 +119,10 @@ def get_data_loader(cfg, rank, world_size):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
-    # CLM removes 1 token, FIM adds at least 3.
+    # Increment seq len to counteract CLM's one token removal.
     data = BufferDataset(
         data,
-        cfg.seq_length - 3 if cfg.fim_training else cfg.seq_length + 1,
+        cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
@@ -145,9 +145,8 @@ def get_data_loader(cfg, rank, world_size):
     # Transform to tensors
     data = PreprocessDataset(data, torch.IntTensor)
 
-    # Apply CLM transformation if needed
-    if not cfg.fim_training:
-        data = PreprocessDataset(data, causal_lm)
+    # Apply CLM transformation
+    data = PreprocessDataset(data, causal_lm)
 
     # Enable auto-saving
     data = CheckpointDataset(

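[Aside, not part of the patch series] The arithmetic behind patch 3: BufferDataset now always packs seq_length + 1 tokens, FIMDataset yields out[:len_] so a transformed line never grows past its incoming length (extra sentinel tokens push the tail off the end and are dropped), and the one-token shift of the CLM step then leaves seq_length input and label positions. A toy illustration of that bookkeeping, with the shift written out as a plain input/label offset:

seq_length = 8
line = list(range(seq_length + 1))  # a packed line of seq_length + 1 token ids
x, y = line[:-1], line[1:]          # the one-token shift applied after packing
assert len(x) == len(y) == seq_length
print(x, y)
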
From ec0cc5d3852e43ddff239684cb04d5d430251386 Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:34:57 -0500
Subject: [PATCH 4/5] reblacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/config/training.py     | 4 +---
 fms_fsdp/utils/dataset_utils.py | 5 ++---
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index f985ed0d..eadef50f 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -15,9 +15,7 @@ class train_config:
     file_type: str = "arrow"
     col_name: str = "tokens"
     tokenizer_path: str = "/fsx/tokenizer"
-    datasets: str = (
-        "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
-    )
+    datasets: str = "lang=en/dataset=commoncrawl,lang=en/dataset=webhose,lang=en/dataset=github_clean,lang=de/dataset=wikipedia,lang=es/dataset=wikipedia,lang=fr/dataset=wikipedia,lang=ja/dataset=wikipedia,lang=pt/dataset=wikipedia,lang=en/dataset=wikimedia,lang=en/dataset=uspto,lang=en/dataset=pubmedcentral,lang=en/dataset=arxiv,lang=en/dataset=stackexchange"
     weights: str = "7725,500,550,28,17,22,25,8,100,500,175,250,100"
     seq_length: int = 4096
     vocab_size: int = 32000
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index e6480dab..9803ce94 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -994,9 +994,8 @@ def __init__(
         self.bos = bos_token
         self.drop = strip_tokens
         self.verbose = verbose
-        self.docset: List[Any] = (
-            []
-        )  # map of doc indices to (shardid, min docid, max docid)
+        # Map of doc indices to (shardid, min docid, max docid)
+        self.docset: List[Any] = []  
 
         # Position
         self.docset_index = 0

From 88db9444e5aa8addc1279a90039db009513f917b Mon Sep 17 00:00:00 2001
From: Davis Wertheimer 
Date: Fri, 10 Jan 2025 16:36:54 -0500
Subject: [PATCH 5/5] Rereblacking

Signed-off-by: Davis Wertheimer 
---
 fms_fsdp/utils/dataset_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 9803ce94..7d634161 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -995,7 +995,7 @@ def __init__(
         self.drop = strip_tokens
         self.verbose = verbose
         # Map of doc indices to (shardid, min docid, max docid)
-        self.docset: List[Any] = []  
+        self.docset: List[Any] = []
 
         # Position
         self.docset_index = 0
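
[Aside, not part of the patch series] A hedged usage sketch of the new knobs, assuming the module paths match the file paths in the diffs, that train_config accepts keyword overrides as a dataclass, and that the remaining fields (data paths, tokenizer, weights, etc.) already point at real tokenized data. The sentinel ids are placeholders; in practice they should be ids reserved in the model's vocabulary.

from fms_fsdp.config.training import train_config
from fms_fsdp.utils.dataloader_utils import get_data_loader

cfg = train_config(
    fim_training=True,
    psm_rate=0.5,    # half of the eligible documents get the PSM layout
    spm_rate=0.25,   # a quarter get SPM; the rest pass through unchanged
    fim_pre=1,       # placeholder sentinel ids (the defaults from patch 1)
    fim_mid=2,
    fim_suf=3,
    bos_token=None,  # required: get_data_loader asserts no BOS token under FIM
)
loader = get_data_loader(cfg, rank=0, world_size=1)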