Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/MaxText/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,8 +1025,6 @@ def __call__(
inputs_positions=inputs_positions,
attention_mask=attention_mask,
)
if index_mask is not None:
index_mask = index_mask[:, None, None, :, :] # [b, 1, 1, q_len, kv_len]

if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
unnormalized_out, _, exp_sum = self.ds_paged_attention_op(
Expand Down
40 changes: 34 additions & 6 deletions src/MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,9 @@ def apply_attention(
Use `dot_product` instead."""
)
return (
self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap, sinks),
self.tpu_flash_attention(
query, key, value, decoder_segment_ids, self.attn_logits_soft_cap, sinks, index_mask
),
None,
None,
)
Expand Down Expand Up @@ -1038,6 +1040,7 @@ def tpu_flash_attention(
decoder_segment_ids: Array | None,
attn_logits_soft_cap: float | None = None,
sinks: Array | None = None,
index_mask: Array | None = None,
) -> Array:
"""TPU Flash Attention."""

Expand All @@ -1063,10 +1066,12 @@ def tpu_flash_attention(
axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep)
axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep)
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep)
index_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH))
else:
axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
index_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))

global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
Expand Down Expand Up @@ -1253,10 +1258,12 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
axis_names_kv,
segment_axis_names_q,
segment_axis_names_kv,
None, # no sharding for config
segment_axis_names_splash_kernel,
None, # no sharding for cp_size
None, # no sharding for load_balanced_context_parallel
sink_axis_names, # sharding align with query heads
index_mask_axis_names,
),
out_specs=axis_names_q,
check_vma=False,
Expand All @@ -1267,10 +1274,12 @@ def wrap_flash_attention(
value,
decoder_segment_ids_q,
decoder_segment_ids_kv,
sa_config,
splash_kernel,
cp_size,
load_balanced_context_parallel,
sinks,
index_mask,
):
# If load_balanced_context_parallel is enabled, reorder the key and value tensors
# to ensure that they are contiguous in memory.
Expand All @@ -1296,10 +1305,25 @@ def wrap_flash_attention(
decoder_segment_ids_tuple = None

if self.config.use_tokamax_splash:
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
query, key, value, decoder_segment_ids_tuple, sinks
)
if self.config.use_sparse_indexer and index_mask is not None:
# Construct the splash kernel call with dynamic mask
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, mask):
splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
mask=mask,
config=sa_config,
)
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
return kernel(q, k, v, segment, sinks=sinks)

# Iterate over batch dimension for (query, key, value, segment, sinks, mask)
attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0))
mask = jnp.isclose(index_mask, 0.0)
attention_output = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, mask)
else:
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
query, key, value, decoder_segment_ids_tuple, sinks
)
elif self.config.use_jax_splash:
materialized_mask = jnp.asarray(mask[:, :])
attention_output = jax_flash_attention.flash_attention_block_masked(
Expand Down Expand Up @@ -1337,17 +1361,20 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
index_mask = _maybe_shard_with_pspec(index_mask, index_mask_axis_names)

x = wrap_flash_attention(
query,
key,
value,
decoder_segment_ids_q,
decoder_segment_ids_kv,
sa_config,
None if self.config.use_jax_splash else splash_kernel,
cp_size,
load_balanced_context_parallel,
sinks,
index_mask,
)

x = jnp.transpose(x, axes=(0, 2, 1, 3))
Expand Down Expand Up @@ -1639,8 +1666,9 @@ def apply_attention_dot(
# Apply index mask, deepseek sparse attention
# index mask contains 0.0 for kept tokens and large negative for masked tokens.
if index_mask is not None:
# index_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
index_mask = index_mask[:, None, None, :, :]
# attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len]
# index_mask: [b, 1, 1, q_len, kv_len]
attn_weights = apply_mask_to_logits(attn_weights, index_mask)

if self.is_partition_in_decode(q_seq_len):
Expand Down
8 changes: 6 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2205,8 +2205,12 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
if self.use_sparse_indexer:
if self.q_lora_rank == 0:
raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.")
if self.attention not in ("dot_product"):
raise ValueError("Sparse indexer is only supported dot_product attention")
supports_dot_product = self.attention == "dot_product"
supports_flash_splash = self.attention == "flash" and self.use_tokamax_splash
if not (supports_dot_product or supports_flash_splash):
raise NotImplementedError(
"Sparse indexer is only supported dot_product attention or flash attention with tokamax splash."
)
if self.attention_type == AttentionType.CHUNK.value and (
not isinstance(self.chunk_attn_window_size, int) or self.chunk_attn_window_size <= 0
):
Expand Down
93 changes: 71 additions & 22 deletions tests/unit/deepseek32_vs_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class Config:
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
use_tokamax_splash: bool = True
# yarn
rope_type: str = "yarn"
original_max_position_embeddings: int = 4096
Expand All @@ -98,7 +99,6 @@ class Config:
use_sparse_indexer: bool = True
index_n_heads: int = 64
index_head_dim: int = 128 # > qk_rope_head_dim
index_topk: int = 4


class ModelArgs:
Expand All @@ -107,7 +107,7 @@ class ModelArgs:
Maps MaxText Config keys to the specific variable names expected by the reference implementation.
"""

def __init__(self, config: Config, max_batch_size: int = 8):
def __init__(self, config: Config, max_batch_size: int = 8, index_topk: int = 4):
self.max_batch_size = max_batch_size
self.scale_fmt = None
self.max_seq_len = config.max_position_embeddings
Expand All @@ -119,6 +119,7 @@ def __init__(self, config: Config, max_batch_size: int = 8):
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.use_tokamax_splash = config.use_tokamax_splash
# yarn
self.original_seq_len = config.original_max_position_embeddings
self.rope_theta = float(config.rope_max_timescale)
Expand All @@ -129,7 +130,7 @@ def __init__(self, config: Config, max_batch_size: int = 8):
# indexer
self.index_n_heads = config.index_n_heads
self.index_head_dim = config.index_head_dim
self.index_topk = config.index_topk
self.index_topk = index_topk


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -457,14 +458,14 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor:

class Indexer(torch.nn.Module): # pylint: disable=missing-class-docstring

def __init__(self, args: ModelArgs):
def __init__(self, args: ModelArgs, index_topk: int = 4):
super().__init__()
self.dim: int = args.dim
self.n_heads: int = args.index_n_heads
self.n_local_heads = args.index_n_heads // world_size
self.head_dim: int = args.index_head_dim
self.rope_head_dim: int = args.qk_rope_head_dim
self.index_topk: int = args.index_topk
self.index_topk: int = index_topk
self.q_lora_rank: int = args.q_lora_rank
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
self.wk = Linear(self.dim, self.head_dim)
Expand Down Expand Up @@ -580,7 +581,7 @@ class MLA(nn.Module):
softmax_scale (float): Scaling factor for softmax in attention computation.
"""

def __init__(self, args: ModelArgs):
def __init__(self, args: ModelArgs, index_topk: int):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
Expand All @@ -605,7 +606,7 @@ def __init__(self, args: ModelArgs):
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale

self.indexer = Indexer(args)
self.indexer = Indexer(args, index_topk)

self.register_buffer(
"kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False
Expand Down Expand Up @@ -750,7 +751,7 @@ def get_jax_mla_weights(pt_mla, cfg):
}


def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len):
def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len, attention, index_topk):
"""Returns MaxText configuration and mesh."""
cfg = pyconfig.initialize(
[None, get_test_config_path()],
Expand All @@ -766,7 +767,8 @@ def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len):
per_device_batch_size=batch_size,
max_target_length=seq_len,
max_prefill_predict_length=seq_len,
attention="dot_product",
attention=attention,
index_topk=index_topk,
**asdict(config),
)
devices_array = maxtext_utils.create_device_mesh(cfg)
Expand All @@ -785,7 +787,7 @@ def setUp(self):
np.random.seed(42)

self.dtype = "float32"
self.batch_size = 2
self.batch_size = 4
self.start_pos = 0
self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42))
# jax config
Expand Down Expand Up @@ -861,6 +863,8 @@ def test_indexer_match(self, seq_len=8):
dtype=self.dtype,
batch_size=self.batch_size,
seq_len=self.seq_len,
attention="dot_product",
index_topk=4,
)

# Indexer specific RoPE (interleave=False)
Expand Down Expand Up @@ -906,17 +910,53 @@ class DeepseekV32MLATest(DeepseekTestBase):
"""Tests for MLA Attention with Sparse Indexing."""

@parameterized.named_parameters(
{"testcase_name": "seq_len=2 (index_topk=4)", "seq_len": 2},
{"testcase_name": "seq_len=8 (index_topk=4)", "seq_len": 8},
{
"testcase_name": "dot_product_s2_k4",
"attention": "dot_product",
"seq_len": 2,
"index_topk": 4,
},
{
"testcase_name": "dot_product_s8_k4",
"attention": "dot_product",
"seq_len": 8,
"index_topk": 4,
},
{
"testcase_name": "dot_product_s128_k4",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{
"testcase_name": "dot_product_s128_k128",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 128,
"check_norm": True,
},
{
"testcase_name": "flash_s128_k4",
"attention": "flash",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{
"testcase_name": "flash_s128_k128",
"attention": "flash",
"seq_len": 128,
"index_topk": 128,
"check_norm": True,
},
)
# index_topk=4
def test_mla_match(self, seq_len=8):
"""Verifies MLA output (train mode) matches PyTorch (MHA mode) with indexer."""

def test_mla_parity(self, attention, seq_len, index_topk, check_norm=False):
"""Verifies JAX MLA output against the PyTorch reference implementation."""
torch_inputs, jax_inputs = self.get_data(seq_len)

# 1. PyTorch Run
pt_mla = MLA(self.pt_args)
pt_mla = MLA(self.pt_args, index_topk)
init_torch_weights(pt_mla)
pt_mla.eval()

Expand All @@ -936,6 +976,8 @@ def test_mla_match(self, seq_len=8):
dtype=self.dtype,
batch_size=self.batch_size,
seq_len=self.seq_len,
attention=attention,
index_topk=index_topk,
)

jax_mla = attention_mla.MLA(
Expand All @@ -959,7 +1001,7 @@ def test_mla_match(self, seq_len=8):
rope_factor=cfg.rope_factor,
max_target_length=self.seq_len,
mesh=mesh,
attention_kernel="dot_product",
attention_kernel=attention,
inputs_q_shape=(self.batch_size, self.seq_len, cfg.emb_dim),
inputs_kv_shape=(self.batch_size, self.seq_len, cfg.emb_dim),
rngs=self.nnx_rng,
Expand All @@ -976,10 +1018,17 @@ def test_mla_match(self, seq_len=8):
model_mode=MODEL_MODE_TRAIN,
)

# 3 Compare
print("torch out", pt_out)
print("jax out", jax_out)
np.testing.assert_allclose(to_jax(pt_out), jax_out, rtol=1e-2, atol=1e-2)
# 3. Compare
if check_norm:
expected = to_jax(pt_out) / jnp.linalg.norm(to_jax(pt_out))
actual = jax_out / jnp.linalg.norm(jax_out)
else:
expected = to_jax(pt_out)
actual = jax_out

print("torch out", expected)
print("jax out", actual)
np.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/train_compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,8 @@ def test_deepseek32(self):
"megablox=True",
"per_device_batch_size=1",
"max_target_length=1024",
"attention=dot_product", # TODO: update to flash attention when it's available.
"attention=flash",
"use_tokamax_splash=True",
"dtype=bfloat16",
"weight_dtype=bfloat16",
# without_device_limit
Expand Down
Loading