diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py
index d08e9b7bf..3419fe6fb 100644
--- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py
+++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py
@@ -8,6 +8,7 @@
 import jax
 import jax.numpy as jnp
+import jmp
 from flax import linen as nn
 
@@ -26,18 +27,24 @@ class ModelConfig:
   use_residual_scaling: bool = True
   tie_embeddings: bool = True  # Whether to tie input and output embed
   qknorm_epsilon: float = 1e-6
-
-  dtype: jnp.dtype = jnp.float32
   attention_init: nn.initializers.Initializer = nn.initializers.normal(
     stddev=0.02
   )
   linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  param_dtype: jnp.dtype = jnp.float32
+  compute_dtype: jnp.dtype = jnp.bfloat16
+  output_dtype: jnp.dtype = jnp.bfloat16
 
   def __post_init__(self):
     self.residual_init = nn.initializers.normal(
       stddev=0.02 / jnp.sqrt(2 * self.num_layers)
     )
+    self.mp_policy = jmp.Policy(
+      compute_dtype=self.compute_dtype,
+      param_dtype=self.param_dtype,
+      output_dtype=self.output_dtype,
+    )
 
 
 class Mlp(nn.Module):
@@ -49,7 +56,11 @@ def __call__(self, x_BxLxD: jax.Array):
     cfg = self.cfg
     linear = partial(
-      nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
+      nn.Dense,
+      kernel_init=cfg.linear_init,
+      use_bias=False,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
     # Adjust hidden dimension to keep the number of parameters invariant to
     # the activation function used since the GLU MLP has 3 * hidden_dim * D
@@ -65,7 +76,8 @@ def __call__(self, x_BxLxD: jax.Array):
     x_BxLxD = nn.Dense(
       cfg.model_dim,
       use_bias=False,
-      dtype=cfg.dtype,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
       kernel_init=cfg.residual_init
       if cfg.use_residual_scaling
       else cfg.linear_init,
@@ -96,7 +108,7 @@ def apply_rope(q, k, freqs_cis):
   def rotate_tensor(x):
     # Split into real and imaginary parts
-    x_r2 = x.reshape(*x.shape[:-1], -1, 2)
+    x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32)
 
     L = x.shape[1]
     freqs = freqs_cis[:, :L, :, :, :]
@@ -109,7 +121,7 @@ def rotate_tensor(x):
       axis=-1,
     )
 
-    return rotated_x_r2.reshape(*x.shape)
+    return rotated_x_r2.reshape(*x.shape).astype(x.dtype)
 
   # Apply rotation to Q and K separately
   rotated_q = rotate_tensor(q)
@@ -141,7 +153,8 @@ def setup(self):
       features=(cfg.num_heads, self.Dh),
       kernel_init=cfg.attention_init,
       use_bias=False,
-      dtype=cfg.dtype,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
     self.multilinear_query = self.multilinear(name='query')
     self.multilinear_key = self.multilinear(name='key')
@@ -150,7 +163,9 @@ def setup(self):
     seq_len = cfg.seq_len
     attn_scale0 = jnp.log2(seq_len**2 - seq_len)
     self.attn_scale = self.param(
-      'attn_scale', nn.initializers.constant(attn_scale0), ()
+      'attn_scale',
+      nn.initializers.constant(attn_scale0, dtype=cfg.compute_dtype),
+      (),
     )
     self.output_projection = nn.DenseGeneral(
       features=cfg.model_dim,
@@ -160,7 +175,8 @@ def setup(self):
       if cfg.use_residual_scaling
       else cfg.linear_init,
       use_bias=False,
-      dtype=cfg.dtype,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
 
   def __call__(self, x_BxLxD: jax.Array):
@@ -177,32 +193,17 @@ def __call__(self, x_BxLxD: jax.Array):
 
     # Apply QK normalization
     q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
     k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps
-
-    # Compute attention scores
-    att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
-
-    # Causal attention mask
-    L = x_BxLxD.shape[1]
-    mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_))
-
-    # Apply mask and softmax
-    _NEG_INF = jnp.finfo(cfg.dtype).min
-    att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
-    att_BxHxLxL = (
-      self.attn_scale * att_BxHxLxL
-    )  # Learned scaling factor for QK norm
-    att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
-    att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)
-
-    # Compute attention output
-    out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh)
-
-    # Reshape and project output
+    q_BxLxHxDh *= self.attn_scale
+    out_BxLxHxDh = jax.nn.dot_product_attention(
+      query=q_BxLxHxDh,
+      key=k_BxLxHxDh,
+      value=v_BxLxHxDh,
+      is_causal=True,
+      scale=1.0,
+      implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None,
+    )
     out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape)
-
-    # Output projection
     out_BxLxD = self.output_projection(out_BxLxD)
-
     return out_BxLxD
@@ -216,16 +217,16 @@ def __call__(self, in_BxLxD: jax.Array):
     cfg = self.docfg
 
     # x = x + attn( attn_norm(x) )
-    x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
-      in_BxLxD
-    )
+    x_BxLxD = nn.RMSNorm(
+      param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )(in_BxLxD)
     x_BxLxD = CausalAttn(cfg)(x_BxLxD)
     x_BxLxD += in_BxLxD
 
     # x = x + mlp( mlp_norm(x) )
-    z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
-      x_BxLxD
-    )
+    z_BxLxD = nn.RMSNorm(
+      param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )(x_BxLxD)
     z_BxLxD = Mlp(cfg)(z_BxLxD)
     return x_BxLxD + z_BxLxD
@@ -242,19 +243,24 @@ def setup(self):
       num_embeddings=cfg.vocab_size,
       features=cfg.model_dim,
       embedding_init=cfg.embed_init,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
     self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)]
-    self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)
+    self.out_ln = nn.RMSNorm(
+      param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )
 
     # Output projection - tied to input embeddings if configured
     if cfg.tie_embeddings:
-      self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
+      self.output_proj = lambda x: self.embed.attend(x)
     else:
       self.output_proj = nn.Dense(
         cfg.vocab_size,
         kernel_init=cfg.embed_init,
-        dtype=cfg.dtype,
+        dtype=cfg.compute_dtype,
+        param_dtype=cfg.param_dtype,
         name='output_proj',
       )
@@ -357,6 +363,7 @@ def main():
 
   # Make a prediction (forward pass)
   print('\nRunning forward pass...')
+  params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL))
  logits = model.apply(params, x_BxL)
 
   # Print output shape and sample values
diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py
index ee4cffbbc..14366d9ea 100644
--- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py
+++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py
@@ -1,9 +1,11 @@
 """LM workload implemented in Jax."""
 
+from functools import partial
 from typing import Any, Dict, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
+import jmp
 
 from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import (
@@ -13,10 +15,33 @@
 from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter
 from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload
 
+replicated_sharding = jax_sharding_utils.get_replicate_sharding()
+batch_sharding = jax_sharding_utils.get_batch_dim_sharding()
+
+# Dtype mapping from string to JAX dtype
+DTYPE_MAP = {
+  'float32': jnp.float32,
+  'float16': jnp.float16,
+  'bfloat16': jnp.bfloat16,
+}
+
 
 class LmWorkload(BaseLmWorkload):
   """LM JAX workload."""
 
+  # Convert dtype strings from base class to JAX dtypes
+  @property
+  def _compute_dtype(self) -> Any:
+    return DTYPE_MAP[self._compute_dtype_str]
+
+  @property
+  def _param_dtype(self) -> Any:
+    return DTYPE_MAP[self._param_dtype_str]
+
+  @property
+  def _output_dtype(self) -> Any:
+    return DTYPE_MAP[self._output_dtype_str]
+
   def _build_input_queue(
     self,
     data_rng: jax.random.PRNGKey,
@@ -53,8 +78,14 @@ def init_model_fn(
       num_layers=self._n_layers,  # num layers
       vocab_size=self._vocab_size,
       expanded_model_dim=self._mlp_dim,  # feedforward dim
-      dtype=jnp.float32,
+      rmsnorm_epsilon=self._rmsnorm_epsilon,
+      qknorm_epsilon=self._qknorm_epsilon,
+      tie_embeddings=self._tie_embeddings,
+      param_dtype=self._param_dtype,
+      compute_dtype=self._compute_dtype,
+      output_dtype=self._output_dtype,
     )
+    self._mp_policy: jmp.Policy = cfg.mp_policy
     self._model = TransformerDo(cfg)
     input_shape = (1, self._seq_len)  # For token IDs
@@ -66,8 +97,7 @@ def init_model_fn(
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     params = jax_sharding_utils.replicate(params)
-    model_state = None
-    return params, model_state
+    return params, None
 
   def model_fn(
     self,
@@ -81,10 +111,12 @@
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode, rng, update_batch_norm, model_state, dropout_rate
     inputs = batch['inputs']
+    params, inputs = self._mp_policy.cast_to_compute((params, inputs))
     # Convert one-hot inputs to token IDs if needed
     if inputs.ndim == 3:  # one-hot encoded
       inputs = jnp.argmax(inputs, axis=-1)
     logits = self._model.apply({'params': params}, inputs)
+    logits = self._mp_policy.cast_to_output(logits)
     return logits, None
 
   def loss_fn(
@@ -139,6 +171,17 @@ def loss_fn(
       'per_example': per_example_losses,
     }
 
+  @partial(
+    jax.jit,
+    static_argnums=(0,),
+    in_shardings=(
+      replicated_sharding,
+      batch_sharding,
+      replicated_sharding,
+      replicated_sharding,
+    ),
+    out_shardings=(replicated_sharding),
+  )
   def _eval_batch(
     self,
     params: spec.ParameterContainer,
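For reference, a minimal standalone sketch of the jmp mixed-precision pattern the JAX changes above rely on (the toy module, shapes, and dtypes below are illustrative assumptions, not code from this patch): master parameters stay in float32, the forward pass runs in bfloat16, and only floating-point leaves are cast, so integer token IDs pass through unchanged.

import jax
import jax.numpy as jnp
import jmp
from flax import linen as nn

# Same dtype split as the ModelConfig defaults added above.
policy = jmp.Policy(
  param_dtype=jnp.float32,
  compute_dtype=jnp.bfloat16,
  output_dtype=jnp.bfloat16,
)

class TinyLM(nn.Module):  # hypothetical stand-in for TransformerDo
  vocab_size: int = 32
  model_dim: int = 16

  @nn.compact
  def __call__(self, ids_BxL):
    x = nn.Embed(self.vocab_size, self.model_dim, param_dtype=jnp.float32)(ids_BxL)
    return nn.Dense(self.vocab_size, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)

model = TinyLM()
ids_BxL = jnp.zeros((2, 8), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), ids_BxL)

# Float leaves become bfloat16, integer token IDs keep their dtype.
params, ids_BxL = policy.cast_to_compute((params, ids_BxL))
logits = policy.cast_to_output(model.apply(params, ids_BxL))
print(logits.dtype)  # bfloat16

Keeping the float32 copy of the parameters outside the compute cast preserves optimizer precision while the bfloat16 matmuls provide the speedup.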
diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py
index edee8318c..4c60198cc 100644
--- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py
+++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py
@@ -26,14 +26,24 @@ class ModelConfig:
   qknorm_epsilon: float = 1e-6
   use_residual_scaling: bool = True
   tie_embeddings: bool = True
+  compute_dtype: torch.dtype = torch.bfloat16
+  param_dtype: torch.dtype = torch.float32
 
 
 class MLP(nn.Module):
-  def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256):
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    multiple_of: int = 256,
+    dtype: torch.dtype = torch.float32,
+  ):
     super().__init__()
-    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-    self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False)
-    self.fc2 = nn.Linear(hidden_dim, dim, bias=False)
+    hidden_dim = int(
+      multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+    )
+    self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False, dtype=dtype)
+    self.fc2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype)
     self.glu = nn.GLU(dim=2)
     nn.init.normal_(self.fc1.weight, std=0.02)
     nn.init.normal_(self.fc2.weight, std=0.02)
@@ -88,8 +98,12 @@ def __init__(self, cfg: ModelConfig):
     self.n_heads = cfg.num_heads
     self.head_dim = cfg.model_dim // cfg.num_heads
 
-    self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False)
-    self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False)
+    self.w_qkv = nn.Linear(
+      cfg.model_dim, 3 * cfg.model_dim, bias=False, dtype=cfg.param_dtype
+    )
+    self.w_out = nn.Linear(
+      cfg.model_dim, cfg.model_dim, bias=False, dtype=cfg.param_dtype
+    )
     # Split into Q, K, V sections
     wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0)
     for w in [wq, wk, wv]:
@@ -99,7 +113,9 @@ def __init__(self, cfg: ModelConfig):
     self.eps = cfg.qknorm_epsilon  # e.g., 1e-6
     seq_len = cfg.seq_len
     attn_scale0 = math.log2(seq_len**2 - seq_len)
-    self.attn_scale = nn.Parameter(torch.tensor(attn_scale0))
+    self.attn_scale = nn.Parameter(
+      torch.tensor(attn_scale0, dtype=cfg.param_dtype)
+    )
 
   def forward(self, x, freqs_cis):
     bsz, seqlen, d = x.shape  # (bsz, seqlen, d)
@@ -142,13 +158,18 @@ class Block(nn.Module):
   def __init__(self, layer_id: int, cfg: ModelConfig):
     super().__init__()
     self.attn = Attention(cfg)
-    self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
+    self.attn_norm = nn.RMSNorm(
+      cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype
+    )
     self.mlp = MLP(
       dim=cfg.model_dim,
       hidden_dim=cfg.expanded_model_dim,
       multiple_of=cfg.multiple_of,
+      dtype=cfg.param_dtype,
+    )
+    self.mlp_norm = nn.RMSNorm(
+      cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype
     )
-    self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
     self.layer_id = layer_id
 
   def forward(self, x, freqs_cis):
@@ -166,12 +187,18 @@ def __init__(self, cfg: ModelConfig):
     head_dim = cfg.model_dim // cfg.num_heads
     assert cfg.model_dim % cfg.num_heads == 0
 
-    self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim)
+    self.embed_tokens = nn.Embedding(
+      cfg.vocab_size, cfg.model_dim, dtype=cfg.param_dtype
+    )
     self.layers = nn.ModuleList(
       [Block(idx, cfg) for idx in range(cfg.num_layers)]
     )
-    self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
-    self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False)
+    self.out_norm = nn.RMSNorm(
+      cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype
+    )
+    self.lm_head = nn.Linear(
+      cfg.model_dim, cfg.vocab_size, bias=False, dtype=cfg.param_dtype
+    )
 
     # Initialize freqs_cis on CPU first (more memory efficient)
     self.register_buffer(
@@ -215,6 +242,7 @@ def forward(self, x, targets=None):
     for layer in self.layers:
       x = layer(x, freqs_cis)  # (bsz, seqlen, dim)
     out = self.lm_head(self.out_norm(x))  # (bsz, seqlen, vocab_size)
+
     if targets is not None:
       loss = F.cross_entropy(
         out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100
@@ -232,40 +260,43 @@ def predict(self, x, k=1):
     Returns:
       Tuple of (input_ids, predicted_ids)
     """
+    # Determine device type for autocast
+    device_type = 'cuda' if x.is_cuda else 'cpu'
 
-    # Store original input
-    original_input = x.clone()
-    generated_input = x.clone()
+    with torch.autocast(device_type=device_type, dtype=self.cfg.compute_dtype):
+      # Store original input
+      original_input = x.clone()
+      generated_input = x.clone()
 
-    # Generate k tokens autoregressively
-    for i in range(k):
-      # Get logits for the entire sequence
-      logits = self(generated_input)
+      # Generate k tokens autoregressively
+      for i in range(k):
+        # Get logits for the entire sequence
+        logits = self(generated_input)
 
-      # Get the logits for the last token in each sequence
-      next_token_logits = logits[:, -1, :]
+        # Get the logits for the last token in each sequence
+        next_token_logits = logits[:, -1, :]
 
-      # Zero out the last token ID to prevent repetition
-      # This is a common issue - the model gets stuck repeating the last token
-      last_token_id = generated_input[:, -1]
-      next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf'))
+        # Zero out the last token ID to prevent repetition
+        # This is a common issue - the model gets stuck repeating the last token
+        last_token_id = generated_input[:, -1]
+        next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf'))
 
-      # Get the most likely token
-      next_token = torch.argmax(next_token_logits, dim=-1)
+        # Get the most likely token
+        next_token = torch.argmax(next_token_logits, dim=-1)
 
-      # Append the predicted token to the sequence
-      next_token = next_token.unsqueeze(1)  # Add sequence dimension
-      generated_input = torch.cat([generated_input, next_token], dim=1)
+        # Append the predicted token to the sequence
+        next_token = next_token.unsqueeze(1)  # Add sequence dimension
+        generated_input = torch.cat([generated_input, next_token], dim=1)
 
-    # For debugging, print predictions for the first item in the batch
-    print('\nPyTorch detailed prediction (first item in batch):')
-    predicted_sequence = generated_input[0, -k:].tolist()
-    print(f'  Predicted token IDs: {predicted_sequence}')
-    for i, token_id in enumerate(predicted_sequence):
-      print(f'  Step {i + 1}: Predicted token {token_id}')
+      # For debugging, print predictions for the first item in the batch
+      print('\nPyTorch detailed prediction (first item in batch):')
+      predicted_sequence = generated_input[0, -k:].tolist()
+      print(f'  Predicted token IDs: {predicted_sequence}')
+      for i, token_id in enumerate(predicted_sequence):
+        print(f'  Step {i + 1}: Predicted token {token_id}')
 
-    # Return all tokens, not just the last k
-    return original_input, generated_input[:, -k:]
+      # Return all tokens, not just the last k
+      return original_input, generated_input[:, -k:]
 
   def _init_weights(self, module):
     if isinstance(module, nn.Linear):
@@ -318,6 +349,8 @@ def main():
   # Instantiate the model
   model = Transformer(config)
   print(f'Model has {model.count_params():,} parameters.')
+  for n, p in model.named_parameters():
+    print(f'{n}.dtype == {p.dtype}')
 
   # Create some random input data
   batch_size = 2
@@ -330,6 +363,7 @@ def main():
   # Run a forward pass
   print(f'Running forward pass with input shape: {input_ids.shape}')
   logits = model(input_ids)
+  print(f'Output logits dtype: {logits.dtype}')
   print(f'Output logits shape: {logits.shape}')
 
   # Run prediction
diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py
index a25ca334a..ed922f9c2 100644
--- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py
+++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py
@@ -19,10 +19,25 @@
 
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
 
+# Dtype mapping from string to PyTorch dtype
+DTYPE_MAP = {
+  'float32': torch.float32,
+  'float16': torch.float16,
+  'bfloat16': torch.bfloat16,
+}
+
 
 class LmWorkload(BaseLmWorkload):
   """LM PyTorch workload."""
 
+  @property
+  def _compute_dtype(self) -> torch.dtype:
+    return DTYPE_MAP[self._compute_dtype_str]
+
+  @property
+  def _param_dtype(self) -> torch.dtype:
+    return DTYPE_MAP[self._param_dtype_str]
+
   def init_model_fn(
     self,
     rng: spec.RandomState,
@@ -40,11 +55,14 @@ def init_model_fn(
       vocab_size=self._vocab_size,
       seq_len=self._seq_len,
       model_dim=self._emb_dim,  # Model dimension
-      expanded_model_dim=self._mlp_dim,  # MLP expansion factor
-      num_layers=self._n_layers,  # Number of transformer layers
-      num_heads=self._n_heads,  # Number of attention heads
-      rmsnorm_epsilon=1e-6,
-      tie_embeddings=True,
+      expanded_model_dim=self._mlp_dim,  # MLP expanded dim
+      num_layers=self._n_layers,
+      num_heads=self._n_heads,
+      rmsnorm_epsilon=self._rmsnorm_epsilon,
+      qknorm_epsilon=self._qknorm_epsilon,
+      tie_embeddings=self._tie_embeddings,
+      compute_dtype=self._compute_dtype,
+      param_dtype=self._param_dtype,
     )
     self._model = Transformer(cfg)
     self._param_shapes = param_utils.pytorch_param_shapes(self._model)
@@ -81,13 +99,18 @@ def model_fn(
       spec.ForwardPassMode.EVAL: torch.no_grad,
       spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
     }
+
+    # Determine device type for autocast
+    device_type = 'cuda' if DEVICE.type == 'cuda' else 'cpu'
+
     with contexts[mode]():
-      # Convert one-hot inputs to token IDs if needed
-      inputs = augmented_and_preprocessed_input_batch['inputs']
-      if inputs.dim() == 3:  # one-hot encoded
-        inputs = inputs.argmax(dim=-1)
+      with torch.autocast(device_type=device_type, dtype=self._compute_dtype):
+        # Convert one-hot inputs to token IDs if needed
+        inputs = augmented_and_preprocessed_input_batch['inputs']
+        if inputs.dim() == 3:  # one-hot encoded
+          inputs = inputs.argmax(dim=-1)
 
-      logits = model(inputs)
+        logits = model(inputs)
 
     return logits, None
@@ -121,7 +144,7 @@ def _build_input_queue(
             batch['targets'], device=DEVICE, dtype=torch.int64
           ),
          'weights': torch.tensor(
-            batch['weights'], device=DEVICE, dtype=torch.float32
+            batch['weights'], device=DEVICE, dtype=self._param_dtype
          )
          if batch['weights'] is not None
          else None,
@@ -157,29 +180,35 @@ def loss_fn(
       - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples.
       - 'per_example': Tensor of shape [batch, length] with individual losses per example.
     """
-    vocab_size = logits_batch.size(-1)
-
-    # Compute cross-entropy loss with label smoothing
-    per_example_losses = torch.nn.functional.cross_entropy(
-      logits_batch.view(-1, vocab_size),
-      label_batch.view(-1),
-      reduction='none',
-      label_smoothing=label_smoothing,
-    )
-    per_example_losses = per_example_losses.view_as(label_batch)
-
-    # Apply weights if provided
-    if mask_batch is not None:
-      per_example_losses = per_example_losses * mask_batch
-
-    # Calculate number of valid examples
-    n_valid_examples = (
-      mask_batch.sum()
-      if mask_batch is not None
-      else torch.tensor(
-        label_batch.numel(), dtype=torch.float32, device=label_batch.device
+    # Determine device type for autocast
+    device_type = 'cuda' if logits_batch.is_cuda else 'cpu'
+
+    with torch.autocast(device_type=device_type, dtype=self._compute_dtype):
+      vocab_size = logits_batch.size(-1)
+
+      # Compute cross-entropy loss with label smoothing
+      per_example_losses = torch.nn.functional.cross_entropy(
+        logits_batch.view(-1, vocab_size),
+        label_batch.view(-1),
+        reduction='none',
+        label_smoothing=label_smoothing,
+      )
+      per_example_losses = per_example_losses.view_as(label_batch)
+
+      # Apply weights if provided
+      if mask_batch is not None:
+        per_example_losses = per_example_losses * mask_batch
+
+      # Calculate number of valid examples
+      n_valid_examples = (
+        mask_batch.sum()
+        if mask_batch is not None
+        else torch.tensor(
+          label_batch.numel(),
+          dtype=self._param_dtype,
+          device=label_batch.device,
+        )
       )
-    )
 
     return {
       'summed': per_example_losses.sum(),
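For reference, a minimal standalone sketch of the torch.autocast pattern applied in model_fn and loss_fn above (the toy module below is an illustrative assumption, not code from this patch): parameters remain float32 while matmuls inside the autocast region run in bfloat16.

import torch
import torch.nn as nn

model = nn.Linear(16, 8)  # hypothetical stand-in for the Transformer
x = torch.randn(4, 16)

device_type = 'cuda' if x.is_cuda else 'cpu'
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
  logits = model(x)             # bfloat16 activations from the autocast matmul
  loss = logits.float().mean()  # cast back to float32 before the reduction

print(model.weight.dtype, logits.dtype, loss.dtype)  # float32 bfloat16 float32

Unlike the jmp approach on the JAX side, autocast leaves module parameters untouched and only changes the dtype of eligible ops inside the context manager.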
""" - vocab_size = logits_batch.size(-1) - - # Compute cross-entropy loss with label smoothing - per_example_losses = torch.nn.functional.cross_entropy( - logits_batch.view(-1, vocab_size), - label_batch.view(-1), - reduction='none', - label_smoothing=label_smoothing, - ) - per_example_losses = per_example_losses.view_as(label_batch) - - # Apply weights if provided - if mask_batch is not None: - per_example_losses = per_example_losses * mask_batch - - # Calculate number of valid examples - n_valid_examples = ( - mask_batch.sum() - if mask_batch is not None - else torch.tensor( - label_batch.numel(), dtype=torch.float32, device=label_batch.device + # Determine device type for autocast + device_type = 'cuda' if logits_batch.is_cuda else 'cpu' + + with torch.autocast(device_type=device_type, dtype=self._compute_dtype): + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), + dtype=self._param_dtype, + device=label_batch.device, + ) ) - ) return { 'summed': per_example_losses.sum(), diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index 5d6e3d742..d95da48ec 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -27,6 +27,16 @@ class BaseLmWorkload(spec.Workload): _mlp_dim: int = 4096 warmup_factor: float = 0.1 + # Model configuration + _rmsnorm_epsilon: float = 1e-6 + _qknorm_epsilon: float = 1e-6 + _tie_embeddings: bool = True + + # Dtype configuration (as strings, to be converted by framework-specific subclasses) + _compute_dtype_str: str = 'bfloat16' + _param_dtype_str: str = 'float32' + _output_dtype_str: str = 'bfloat16' # Only used by JAX + def __init__(self) -> None: super().__init__() self._param_shapes = None @@ -85,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 31_967 # 8.9 hours + return 31_967 # 8.9 hours @property def eval_period_time_sec(self) -> int: - return 2_571 # approximately 25 evals + return 2_571 # approximately 25 evals @property def step_hint(self) -> int: @@ -164,9 +174,9 @@ def _eval_model_on_split( eval_batch = next(self._eval_iters[split]) metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): - if metric_name not in eval_metrics: - eval_metrics[metric_name] = 0.0 - eval_metrics[metric_name] += metric_value + eval_metrics.update( + {metric_name: eval_metrics.get(metric_name, 0.0) + metric_value} + ) eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['ppl'] = np.exp(eval_results['loss']).item() diff --git a/pyproject.toml b/pyproject.toml index 006e7e5cd..e3d86df3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.26", "datasets==3.6.0"] +lm = ["transformers==4.26.0", "datasets==3.6.0"] # Frameworks jax_core_deps = [ @@ -99,6 +99,7 @@ 
@@ -99,6 +99,7 @@ jax_core_deps = [
   "chex==0.1.86",
   "ml_dtypes==0.5.1",
   "protobuf==4.25.5",
+  "jmp>=0.0.4"
 ]
 jax_cpu = [
   "jax==0.7.0",