8 changes: 8 additions & 0 deletions fms_fsdp/config/training.py
@@ -72,3 +72,11 @@ class train_config:
stage2_prompt_length: int = 64
stage2_batch_size: int = 96
stage2_seq_length: int = 256

# FIM training
fim_training: bool = False
psm_rate: float = 0.0
spm_rate: float = 0.0
fim_pre: int = 1
fim_mid: int = 2
fim_suf: int = 3
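
For orientation, a minimal sketch of a FIM run configuration using the fields added above. It assumes train_config is a dataclass that accepts keyword overrides and already exposes a bos_token field (that field is read by the new assert in get_data_loader); the rate and sentinel values are illustrative, not recommendations.

from fms_fsdp.config.training import train_config

# Hypothetical FIM configuration; only the field names come from this diff.
cfg = train_config(
    fim_training=True,
    psm_rate=0.5,     # half of eligible documents get the PSM layout
    spm_rate=0.25,    # a quarter get the SPM layout; the rest stay in plain causal order
    fim_pre=1,        # <PRE> sentinel id
    fim_mid=2,        # <MID> sentinel id
    fim_suf=3,        # <SUF> sentinel id
    bos_token=None,   # the new FIM path asserts that no BOS token is configured
)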
33 changes: 24 additions & 9 deletions fms_fsdp/utils/dataloader_utils.py
@@ -5,6 +5,7 @@
AutoHandler,
BufferDataset,
CheckpointDataset,
FIMDataset,
ParquetHandler,
PreloadBufferDataset,
PreprocessDataset,
@@ -57,9 +58,9 @@ def __iter__(self):
return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)


def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
def get_data_loader(cfg, rank, world_size):
"""
Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
Pytorch dataloader for stateful, distributed, and rescalable language model training.
Assumes underlying data is sequences of integer values.
...
Args
@@ -70,11 +71,11 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
Rank of current distributed worker. Used for handling dataset sharding logic.
world_size : int
Number of distributed workers. Used for handling dataset sharding logic.
postprocess : List[Callable]
Any task-specific postprocessing to apply before handing over data. Steps will apply in
the order provided by the user. For CLM training, use postprocess=[causal_lm].
"""

if cfg.fim_training:
assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"

datasets, weights = parse_data_args(cfg.datasets, cfg.weights)

# Base streaming dataset. Returns doc chunks in sequence.
@@ -118,20 +119,34 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
verbose=(rank == 0),
)
# Wrap above dataset in packing logic to form constant-length lines.
# Increment seq len to counteract CLM's one token removal.
data = BufferDataset(
data,
cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
cfg.seq_length + 1,
bos_token=cfg.bol_token,
eos_token=cfg.eol_token,
pack_hard=True,
)
# Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
data = PreloadBufferDataset(data, 10000)

# Apply desired postprocessing steps in sequence
# Apply FIM transformation if needed
if cfg.fim_training:
data = FIMDataset(
data,
cfg.eos_token,
cfg.psm_rate,
cfg.spm_rate,
pre_token=cfg.fim_pre,
mid_token=cfg.fim_mid,
suf_token=cfg.fim_suf,
)

# Transform to tensors
data = PreprocessDataset(data, torch.IntTensor)
for p in postprocess:
data = PreprocessDataset(data, p)

# Apply CLM transformation
data = PreprocessDataset(data, causal_lm)

# Enable auto-saving
data = CheckpointDataset(
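
As a reminder of why the packed line is built at seq_length + 1, here is a toy sketch of the one-token shift a causal-LM postprocess performs. The exact behaviour of this repo's causal_lm helper is not shown in the diff, so the function below is an assumption used only to illustrate the length bookkeeping.

import torch

def toy_causal_lm(line: torch.Tensor):
    # Predict token t+1 from tokens up to t: inputs drop the last token,
    # labels drop the first, so one token is consumed by the shift.
    return line[:-1], line[1:]

line = torch.arange(9)                        # a packed line of seq_length + 1 = 9 tokens
inputs, labels = toy_causal_lm(line)
assert inputs.numel() == labels.numel() == 8  # back to seq_length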
127 changes: 124 additions & 3 deletions fms_fsdp/utils/dataset_utils.py
@@ -696,6 +696,128 @@ def load_state_dict(self, state_dicts, sharded_input=False):
return sharded_dicts


class FIMDataset(_WrapperDataset):
"""
Wrapper for a _StatefulDataset that implements Fill-In-the-Middle (FIM) training
(https://arxiv.org/pdf/2207.14255).
Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
Breaks sequence apart into component document spans, and for each document span
of sufficient length, transforms with specified probability into:
PSM mode: <PRE> (prefix) <SUF> (suffix) <MID> (middle) <EOS>
SPM mode: <PRE> <SUF> (suffix) <MID> (prefix) (middle) <EOS>
Any of the FIM sentinel tokens can be omitted by passing in None.
Any extra tokens after transformation are dropped from the end of the sequence.
...
Args
----
dataset : _StatefulDataset
Fully instantiated dataset
delimiter_token : any
Token used to indicate document boundaries
psm_rate : float
Probability of applying the PSM transformation; psm_rate + spm_rate cannot exceed 1.
spm_rate : float
Probability of applying the SPM transformation; psm_rate + spm_rate cannot exceed 1.
min_len : int
Minimum document length to perform FIM transformation
pre_token : any | None
Token used to indicate prefix section of the document
mid_token : any | None
Token used to indicate middle infill section of the document
suf_token : any | None
Token used to indicate suffix section of the document
"""

def __init__(
self,
dataset: _StatefulDataset,
delimiter_token: Any,
psm_rate: float = 0.0,
spm_rate: float = 0.0,
min_len: int = 10,
pre_token=None,
mid_token=None,
suf_token=None,
):
super().__init__(dataset)
assert (
psm_rate + spm_rate > 0
), "FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
assert (
psm_rate + spm_rate <= 1
), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
self.psm = psm_rate
self.spm = spm_rate
self.delimiter = delimiter_token
self.min_len = min_len
self.pref = pre_token
self.suff = suf_token
self.midd = mid_token

self.g_state = None
self.generator = torch.Generator().manual_seed(self.rank)
self.state_params = ["g_state"]

def __iter__(self):
dataset = iter(self.dataset)
while True:
inp = next(dataset)
len_ = len(inp)
i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
docs = [
inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
] # list[list[any]]
out = []
for i in range(len(docs)):
doc = docs[i]
if len(docs[i]) >= self.min_len:
# decide psm, spm, or nothing
thresh = torch.rand([1], generator=self.generator).item()
if thresh < self.psm + self.spm:
# Split doc
doc = []
if self.pref:
doc = [self.pref]
splits = torch.randint(
0, len(docs[i]), [2], generator=self.generator
).tolist()
pre = docs[i][: min(splits)]
mid = docs[i][min(splits) : max(splits)]
suf = docs[i][max(splits) :]

if thresh < self.psm:
# PSM transformation
doc += pre
if self.suff:
doc.append(self.suff)
doc += suf
if self.midd:
doc.append(self.midd)
doc += mid
else:
# SPM transformation
if self.suff:
doc.append(self.suff)
doc += suf
if self.midd:
doc.append(self.midd)
doc += pre + mid
out += doc + [self.delimiter]
yield out[:len_]

def state_dict(self):
# Write generator state manually
self.g_state = self.generator.get_state()
return super().state_dict()

def load_state_dict(self, state_dicts, sharded_input=False):
sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
# Manually set generator state if it exists
if self.g_state is not None:
self.generator.set_state(self.g_state)
return sharded_dicts


class BufferDataset(_WrapperDataset):
"""
Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -872,9 +994,8 @@ def __init__(
self.bos = bos_token
self.drop = strip_tokens
self.verbose = verbose
self.docset: List[
Any
] = [] # map of doc indices to (shardid, min docid, max docid)
# Map of doc indices to (shardid, min docid, max docid)
self.docset: List[Any] = []

# Position
self.docset_index = 0
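
To make the PSM/SPM layouts from the FIMDataset docstring concrete, here is a self-contained sketch of the split on one toy document. The sentinel ids mirror the fim_pre/fim_mid/fim_suf defaults added to train_config; the EOS/delimiter id of 0 is an assumption, and the split logic mirrors the torch.randint call in __iter__.

import torch

PRE, MID, SUF, EOS = 1, 2, 3, 0          # fim_pre / fim_mid / fim_suf defaults; EOS id assumed
doc = list(range(100, 112))              # one 12-token document

g = torch.Generator().manual_seed(0)
a, b = sorted(torch.randint(0, len(doc), [2], generator=g).tolist())
prefix, middle, suffix = doc[:a], doc[a:b], doc[b:]

# PSM: <PRE> prefix <SUF> suffix <MID> middle <EOS>
psm = [PRE] + prefix + [SUF] + suffix + [MID] + middle + [EOS]
# SPM: <PRE> <SUF> suffix <MID> prefix middle <EOS>
spm = [PRE] + [SUF] + suffix + [MID] + prefix + middle + [EOS]

Both layouts keep every original token and only add sentinels, so after concatenating the transformed documents the wrapper truncates the packed line back to its original length, which is what yield out[:len_] does above.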
8 changes: 5 additions & 3 deletions main_training_llama.py
@@ -122,9 +122,11 @@ def main(**kwargs):
model,
optimizer,
None,
path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
if not os.path.isfile(cfg.ckpt_load_path)
else cfg.ckpt_load_path,
path=(
os.path.join(cfg.ckpt_load_path, "checkpoints/")
if not os.path.isfile(cfg.ckpt_load_path)
else cfg.ckpt_load_path
),
strict=False,
)
if not is_resuming:
8 changes: 5 additions & 3 deletions main_training_mamba.py
@@ -119,9 +119,11 @@ def main(**kwargs):
model,
optimizer,
None,
path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
if not os.path.isfile(cfg.ckpt_load_path)
else cfg.ckpt_load_path,
path=(
os.path.join(cfg.ckpt_load_path, "checkpoints/")
if not os.path.isfile(cfg.ckpt_load_path)
else cfg.ckpt_load_path
),
strict=False,
)
if not is_resuming:
4 changes: 4 additions & 0 deletions tests/test_datasets.py
@@ -632,6 +632,10 @@ def test_multi_reload_stress():
# preload / sample / scale / doc pipeline
multi_reload_stress_check(lambda: d6(d5(d4())))

# Add FIM dataset
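# Positional args: delimiter_token=-1, psm_rate=0.25, spm_rate=0.25, min_len=10, pre_token=-2, mid_token=-3, suf_token=-4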
d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x]
multi_reload_stress_check(lambda: d7(d6(d5(d4()))))


# SCALABLEDATASET TESTS
