diff --git a/megatron/arguments.py b/megatron/arguments.py
index 31a8d4000..c5e3faefd 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -375,6 +375,10 @@ def _add_network_size_args(parser):
                        ', needs to be divisible by TP size and `make-vocab-size-divisible-by`.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--layernorm-tp-auto-sync', action='store_true',
+                       help='Force syncing layernorm params across TP ranks in forward. '
+                       'This is a workaround for an unresolved bug leading to TP ranks '
+                       'getting out of sync with each other.')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residula connection '
diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
index 7c3b7a6f4..e95b1b41b 100644
--- a/megatron/data/data_samplers.py
+++ b/megatron/data/data_samplers.py
@@ -40,7 +40,7 @@ def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int)
             'target_tokens': array([5])
         }
     ]
-    
+
     Output:
         decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
         decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
@@ -139,6 +139,7 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
         dataset,
         batch_sampler=batch_sampler,
         num_workers=num_workers,
+        generator=torch.Generator().manual_seed(args.seed),
         collate_fn=collate_fn,
         pin_memory=True
     )
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 563566b70..ed914d9a0 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -19,15 +19,16 @@
 
 import numbers
+
+from megatron import get_args
+from megatron import mpu
 from packaging import version
-import torch
 from torch import nn
-from torch.nn.parameter import Parameter
-import torch.nn.functional as F
 from torch.nn import init
+from torch.nn.parameter import Parameter
 import importlib
-
-from megatron import get_args
+import torch
+import torch.nn.functional as F
 
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
 
@@ -83,6 +84,7 @@ def __init__(self, normalized_shape, eps=1e-5):
         self.reset_parameters()
 
         args = get_args()
+        self.layernorm_tp_auto_sync = args.layernorm_tp_auto_sync
 
         self.use_meg_ds_fused_layer_norm = (
             args.bf16  # Current Meg-DS cuda kernel has better throughput than torch.nn.LayerNorm
@@ -97,6 +99,11 @@ def reset_parameters(self):
 
 
     def forward(self, input):
+
+        if self.layernorm_tp_auto_sync:
+            torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+            torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
         if self.use_meg_ds_fused_layer_norm:
             return FusedLayerNormAffineFunction.apply(
                 input, self.weight, self.bias, self.normalized_shape, self.eps)
diff --git a/tests/test_training.py b/tests/test_training.py
index c77cb9af2..79a43c6a2 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -147,6 +147,7 @@ def get_variation_config(self, variation, output_dir, n_samples=None):
                 --clip-grad 1.0
                 --weight-decay 1e-1
                 --embed-layernorm
+                --layernorm-tp-auto-sync
                 --fp16
                 --log-level debug
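
For reference, a minimal standalone sketch of the sync that `--layernorm-tp-auto-sync` performs in `fused_layer_norm.py` above. The helper name and the explicit `group` argument are illustrative only and not part of this diff; Megatron obtains the group from `mpu.get_tensor_model_parallel_group()`.

    import torch
    import torch.distributed as dist

    def sync_replicated_param(param: torch.Tensor, group=None) -> None:
        # Average the parameter in place across all ranks of `group`, so replicas
        # that have drifted apart (the unresolved TP desync bug this flag works
        # around) agree again before the parameter is used.
        # Note: ReduceOp.AVG requires the NCCL backend; with gloo, all_reduce with
        # SUM and then divide by the group's world size instead.
        dist.all_reduce(param.data, op=dist.ReduceOp.AVG, group=group)

    # Intended usage inside forward(), guarded by the new flag:
    #     if self.layernorm_tp_auto_sync:
    #         sync_replicated_param(self.weight, tp_group)
    #         sync_replicated_param(self.bias, tp_group)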