91 changes: 49 additions & 42 deletions algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py
@@ -8,6 +8,7 @@

import jax
import jax.numpy as jnp
import jmp
from flax import linen as nn


@@ -26,18 +27,24 @@ class ModelConfig:
use_residual_scaling: bool = True
tie_embeddings: bool = True # Whether to tie input and output embed
qknorm_epsilon: float = 1e-6

dtype: jnp.dtype = jnp.float32
attention_init: nn.initializers.Initializer = nn.initializers.normal(
stddev=0.02
)
linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
param_dtype: jnp.dtype = jnp.float32
compute_dtype: jnp.dtype = jnp.bfloat16
output_dtype: jnp.dtype = jnp.bfloat16

def __post_init__(self):
self.residual_init = nn.initializers.normal(
stddev=0.02 / jnp.sqrt(2 * self.num_layers)
)
self.mp_policy = jmp.Policy(
compute_dtype=self.compute_dtype,
param_dtype=self.param_dtype,
output_dtype=self.output_dtype,
)
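
Note (not part of the diff): jmp.Policy just bundles the three dtypes and exposes pytree-casting helpers. A minimal sketch of the pattern the rest of this PR relies on, with illustrative names and shapes:

import jax.numpy as jnp
import jmp

policy = jmp.Policy(
  param_dtype=jnp.float32,     # parameters stay in full precision
  compute_dtype=jnp.bfloat16,  # forward pass runs in bfloat16
  output_dtype=jnp.bfloat16,   # logits are emitted in bfloat16
)

params = {'w': jnp.ones((4, 4), jnp.float32)}
x = jnp.ones((2, 4), jnp.float32)

# Cast a whole pytree to the compute dtype before the forward pass ...
params_c, x_c = policy.cast_to_compute((params, x))
y = x_c @ params_c['w']
# ... and cast the result to the output dtype afterwards.
y = policy.cast_to_output(y)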


class Mlp(nn.Module):
@@ -49,7 +56,11 @@ class Mlp(nn.Module):
def __call__(self, x_BxLxD: jax.Array):
cfg = self.cfg
linear = partial(
nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
nn.Dense,
kernel_init=cfg.linear_init,
use_bias=False,
dtype=cfg.compute_dtype,
param_dtype=cfg.param_dtype,
)
# Adjust hidden dimension to keep the number of parameters invariant to
# the activation function used since the GLU MLP has 3 * hidden_dim * D
@@ -65,7 +76,8 @@ def __call__(self, x_BxLxD: jax.Array):
x_BxLxD = nn.Dense(
cfg.model_dim,
use_bias=False,
dtype=cfg.dtype,
dtype=cfg.compute_dtype,
param_dtype=cfg.param_dtype,
kernel_init=cfg.residual_init
if cfg.use_residual_scaling
else cfg.linear_init,
@@ -96,7 +108,7 @@ def apply_rope(q, k, freqs_cis):

def rotate_tensor(x):
# Split into real and imaginary parts
x_r2 = x.reshape(*x.shape[:-1], -1, 2)
x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32)
L = x.shape[1]
freqs = freqs_cis[:, :L, :, :, :]

@@ -109,7 +121,7 @@ def rotate_tensor(x):
axis=-1,
)

return rotated_x_r2.reshape(*x.shape)
return rotated_x_r2.reshape(*x.shape).astype(x.dtype)

# Apply rotation to Q and K separately
rotated_q = rotate_tensor(q)
@@ -141,7 +153,8 @@ def setup(self):
features=(cfg.num_heads, self.Dh),
kernel_init=cfg.attention_init,
use_bias=False,
dtype=cfg.dtype,
dtype=cfg.compute_dtype,
param_dtype=cfg.param_dtype,
)
self.multilinear_query = self.multilinear(name='query')
self.multilinear_key = self.multilinear(name='key')
@@ -150,7 +163,9 @@ def setup(self):
seq_len = cfg.seq_len
attn_scale0 = jnp.log2(seq_len**2 - seq_len)
self.attn_scale = self.param(
'attn_scale', nn.initializers.constant(attn_scale0), ()
'attn_scale',
nn.initializers.constant(attn_scale0, dtype=cfg.compute_dtype),
(),
)
self.output_projection = nn.DenseGeneral(
features=cfg.model_dim,
@@ -160,7 +175,8 @@ def setup(self):
if cfg.use_residual_scaling
else cfg.linear_init,
use_bias=False,
dtype=cfg.dtype,
dtype=cfg.compute_dtype,
param_dtype=cfg.param_dtype,
)

def __call__(self, x_BxLxD: jax.Array):
@@ -177,32 +193,17 @@ def __call__(self, x_BxLxD: jax.Array):
# Apply QK normalization
q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps

# Compute attention scores
att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)

# Causal attention mask
L = x_BxLxD.shape[1]
mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_))

# Apply mask and softmax
_NEG_INF = jnp.finfo(cfg.dtype).min
att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
att_BxHxLxL = (
self.attn_scale * att_BxHxLxL
) # Learned scaling factor for QK norm
att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)

# Compute attention output
out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh)

# Reshape and project output
q_BxLxHxDh *= self.attn_scale
out_BxLxHxDh = jax.nn.dot_product_attention(
query=q_BxLxHxDh,
key=k_BxLxHxDh,
value=v_BxLxHxDh,
is_causal=True,
scale=1.0,
implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None,
)
out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape)

# Output projection
out_BxLxD = self.output_projection(out_BxLxD)

return out_BxLxD
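
Note (reviewer sketch, not part of the diff): folding the learned attn_scale into q and passing scale=1.0 keeps the fused jax.nn.dot_product_attention call equivalent to the removed mask-and-softmax path. A small float32 equivalence check with illustrative shapes:

import jax
import jax.numpy as jnp

B, L, H, Dh = 2, 8, 4, 16
key = jax.random.PRNGKey(0)
q, k, v = (
  jax.random.normal(jax.random.fold_in(key, i), (B, L, H, Dh)) for i in range(3)
)
attn_scale = 3.0

# Removed path: scaled logits, causal mask, softmax, weighted sum of values.
logits = attn_scale * jnp.einsum('...qhd,...khd->...hqk', q, k)
mask = jnp.tril(jnp.ones((1, 1, L, L), dtype=bool))
logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
att = jax.nn.softmax(logits, axis=-1)
out_manual = jnp.einsum('...hqk,...khd->...qhd', att, v)

# New path: scale folded into q, built-in 1/sqrt(Dh) scaling disabled.
out_fused = jax.nn.dot_product_attention(
  attn_scale * q, k, v, is_causal=True, scale=1.0
)

assert jnp.allclose(out_manual, out_fused, atol=1e-4)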


@@ -216,16 +217,16 @@ def __call__(self, in_BxLxD: jax.Array):
cfg = self.docfg

# x = x + attn( attn_norm(x) )
x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
in_BxLxD
)
x_BxLxD = nn.RMSNorm(
param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
)(in_BxLxD)
x_BxLxD = CausalAttn(cfg)(x_BxLxD)
x_BxLxD += in_BxLxD

# x = x + mlp( mlp_norm(x) )
z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
x_BxLxD
)
z_BxLxD = nn.RMSNorm(
param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
)(x_BxLxD)
z_BxLxD = Mlp(cfg)(z_BxLxD)

return x_BxLxD + z_BxLxD
@@ -242,19 +243,24 @@ def setup(self):
num_embeddings=cfg.vocab_size,
features=cfg.model_dim,
embedding_init=cfg.embed_init,
dtype=cfg.compute_dtype,
param_dtype=cfg.param_dtype,
)

self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)]
self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)
self.out_ln = nn.RMSNorm(
param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
)

# Output projection - tied to input embeddings if configured
if cfg.tie_embeddings:
self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
self.output_proj = lambda x: self.embed.attend(x)
else:
self.output_proj = nn.Dense(
cfg.vocab_size,
kernel_init=cfg.embed_init,
dtype=cfg.dtype,
dtype=cfg.compute_dtype,
param_dtype=cfg.param_dtype,
name='output_proj',
)

@@ -357,6 +363,7 @@ def main():

# Make a prediction (forward pass)
print('\nRunning forward pass...')
params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL))
logits = model.apply(params, x_BxL)

# Print output shape and sample values
49 changes: 46 additions & 3 deletions algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py
@@ -1,9 +1,11 @@
"""LM workload implemented in Jax."""

from functools import partial
from typing import Any, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import jmp

from algoperf import jax_sharding_utils, param_utils, spec
from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import (
@@ -13,10 +15,33 @@
from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter
from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload

replicated_sharding = jax_sharding_utils.get_replicate_sharding()
batch_sharding = jax_sharding_utils.get_batch_dim_sharding()

# Dtype mapping from string to JAX dtype
DTYPE_MAP = {
'float32': jnp.float32,
'float16': jnp.float16,
'bfloat16': jnp.bfloat16,
}


class LmWorkload(BaseLmWorkload):
"""LM JAX workload."""

# Convert dtype strings from base class to JAX dtypes
@property
def _compute_dtype(self) -> Any:
return DTYPE_MAP[self._compute_dtype_str]

@property
def _param_dtype(self) -> Any:
return DTYPE_MAP[self._param_dtype_str]

@property
def _output_dtype(self) -> Any:
return DTYPE_MAP[self._output_dtype_str]
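
Note (illustrative, not from the diff): the *_dtype_str attributes are assumed to be plain strings set on BaseLmWorkload (e.g. 'bfloat16'); the properties above only resolve them lazily through DTYPE_MAP:

import jax.numpy as jnp

DTYPE_MAP = {'float32': jnp.float32, 'float16': jnp.float16, 'bfloat16': jnp.bfloat16}

class _DtypeDemo:  # hypothetical stand-in for the workload base class
  _compute_dtype_str = 'bfloat16'

  @property
  def _compute_dtype(self):
    return DTYPE_MAP[self._compute_dtype_str]

assert _DtypeDemo()._compute_dtype == jnp.bfloat16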

def _build_input_queue(
self,
data_rng: jax.random.PRNGKey,
@@ -53,8 +78,14 @@ def init_model_fn(
num_layers=self._n_layers, # num layers
vocab_size=self._vocab_size,
expanded_model_dim=self._mlp_dim, # feedforward dim
dtype=jnp.float32,
rmsnorm_epsilon=self._rmsnorm_epsilon,
qknorm_epsilon=self._qknorm_epsilon,
tie_embeddings=self._tie_embeddings,
param_dtype=self._param_dtype,
compute_dtype=self._compute_dtype,
output_dtype=self._output_dtype,
)
self._mp_policy: jmp.Policy = cfg.mp_policy
self._model = TransformerDo(cfg)
input_shape = (1, self._seq_len) # For token IDs

@@ -66,8 +97,7 @@ def init_model_fn(
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
params = jax_sharding_utils.replicate(params)
model_state = None
return params, model_state
return params, None

def model_fn(
self,
@@ -81,10 +111,12 @@
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode, rng, update_batch_norm, model_state, dropout_rate
inputs = batch['inputs']
params, inputs = self._mp_policy.cast_to_compute((params, inputs))
# Convert one-hot inputs to token IDs if needed
if inputs.ndim == 3: # one-hot encoded
inputs = jnp.argmax(inputs, axis=-1)
logits = self._model.apply({'params': params}, inputs)
logits = self._mp_policy.cast_to_output(logits)
return logits, None

def loss_fn(
@@ -139,6 +171,17 @@ def loss_fn(
'per_example': per_example_losses,
}

@partial(
jax.jit,
static_argnums=(0,),
in_shardings=(
replicated_sharding,
batch_sharding,
replicated_sharding,
replicated_sharding,
),
out_shardings=(replicated_sharding),
)
def _eval_batch(
self,
params: spec.ParameterContainer,
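
Note (sketch, not from the PR): the decorator above compiles _eval_batch with self marked static and explicit shardings, so params and rng stay replicated while the batch is split along its leading axis. The general pattern, with an illustrative mesh and a stand-in forward pass:

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('batch',))
replicated = NamedSharding(mesh, P())          # full copy on every device
batch_split = NamedSharding(mesh, P('batch'))  # leading axis sharded

@partial(
  jax.jit,
  in_shardings=(replicated, batch_split),
  out_shardings=replicated,
)
def eval_batch(params, inputs):
  # Stand-in for the model forward pass and metric reduction.
  return jnp.mean(inputs @ params['w'])

out = eval_batch({'w': jnp.ones((4, 2))}, jnp.ones((8, 4)))  # replicated scalar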