diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 5fa56793..22b9a840 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -29,6 +29,7 @@ class train_config:
     logical_shards: int = 1024
     num_workers: int = 1
     doc_cutoff: int = 1_000_000
+    doc_breakpoint: int = 65_536
 
     # fsdp policies
     sharding_strategy: str = "hsdp"
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 022ca4b5..09f2c1ec 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -1,4 +1,5 @@
 import torch
+from math import ceil
 
 from fms_fsdp.utils.dataset_utils import (
     ArrowHandler,
@@ -94,7 +95,7 @@ def get_data_loader(cfg, rank, world_size):
         )
     else:
         filehandler = _handler_map[cfg.file_type](cols)
 
-
+    # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
@@ -105,6 +106,7 @@
         bos_token=cfg.bos_token,
         strip_tokens=set(droplist),
         min_length=3,
+        max_consecutive_chunks=ceil(cfg.doc_breakpoint/1024),
         seed=cfg.seed,
     )
     # Add rescaling/resharding