From b9457e4f417515d4075055124abaece74e80d992 Mon Sep 17 00:00:00 2001 From: Rissy Ran Date: Tue, 27 Jan 2026 23:24:48 +0000 Subject: [PATCH] Integrate sparse attention with flash attention --- src/MaxText/layers/attention_mla.py | 2 - src/MaxText/layers/attention_op.py | 40 ++++++++-- src/maxtext/configs/types.py | 8 +- tests/unit/deepseek32_vs_reference_test.py | 93 +++++++++++++++++----- tests/unit/train_compile_test.py | 3 +- 5 files changed, 113 insertions(+), 33 deletions(-) diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py index 7d28d45fda..7ab6d241c2 100644 --- a/src/MaxText/layers/attention_mla.py +++ b/src/MaxText/layers/attention_mla.py @@ -1025,8 +1025,6 @@ def __call__( inputs_positions=inputs_positions, attention_mask=attention_mask, ) - if index_mask is not None: - index_mask = index_mask[:, None, None, :, :] # [b, 1, 1, q_len, kv_len] if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN: unnormalized_out, _, exp_sum = self.ds_paged_attention_op( diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index ac967295a7..a67e62223b 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -881,7 +881,9 @@ def apply_attention( Use `dot_product` instead.""" ) return ( - self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap, sinks), + self.tpu_flash_attention( + query, key, value, decoder_segment_ids, self.attn_logits_soft_cap, sinks, index_mask + ), None, None, ) @@ -1038,6 +1040,7 @@ def tpu_flash_attention( decoder_segment_ids: Array | None, attn_logits_soft_cap: float | None = None, sinks: Array | None = None, + index_mask: Array | None = None, ) -> Array: """TPU Flash Attention.""" @@ -1063,10 +1066,12 @@ def tpu_flash_attention( axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep) axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep) axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep) + index_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH)) else: axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q) axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) + index_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel @@ -1253,10 +1258,12 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): axis_names_kv, segment_axis_names_q, segment_axis_names_kv, + None, # no sharding for config segment_axis_names_splash_kernel, None, # no sharding for cp_size None, # no sharding for load_balanced_context_parallel sink_axis_names, # sharding align with query heads + index_mask_axis_names, ), out_specs=axis_names_q, check_vma=False, @@ -1267,10 +1274,12 @@ def wrap_flash_attention( value, decoder_segment_ids_q, decoder_segment_ids_kv, + sa_config, splash_kernel, cp_size, load_balanced_context_parallel, sinks, + index_mask, ): # If load_balanced_context_parallel is enabled, reorder the key and value tensors # to ensure that they are contiguous in memory. @@ -1296,10 +1305,25 @@ def wrap_flash_attention( decoder_segment_ids_tuple = None if self.config.use_tokamax_splash: - kernel = partial(splash_kernel, max_logit_value=max_logit_value) - attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))( - query, key, value, decoder_segment_ids_tuple, sinks - ) + if self.config.use_sparse_indexer and index_mask is not None: + # Construct the splash kernel call with dynamic mask + def dynamic_mask_splash_kernel(q, k, v, segment, sinks, mask): + splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha( + mask=mask, + config=sa_config, + ) + kernel = partial(splash_kernel, max_logit_value=max_logit_value) + return kernel(q, k, v, segment, sinks=sinks) + + # Iterate over batch dimension for (query, key, value, segment, sinks, mask) + attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0)) + mask = jnp.isclose(index_mask, 0.0) + attention_output = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, mask) + else: + kernel = partial(splash_kernel, max_logit_value=max_logit_value) + attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))( + query, key, value, decoder_segment_ids_tuple, sinks + ) elif self.config.use_jax_splash: materialized_mask = jnp.asarray(mask[:, :]) attention_output = jax_flash_attention.flash_attention_block_masked( @@ -1337,6 +1361,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q) decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv) sinks = _maybe_shard_with_pspec(sinks, sink_axis_names) + index_mask = _maybe_shard_with_pspec(index_mask, index_mask_axis_names) x = wrap_flash_attention( query, @@ -1344,10 +1369,12 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): value, decoder_segment_ids_q, decoder_segment_ids_kv, + sa_config, None if self.config.use_jax_splash else splash_kernel, cp_size, load_balanced_context_parallel, sinks, + index_mask, ) x = jnp.transpose(x, axes=(0, 2, 1, 3)) @@ -1639,8 +1666,9 @@ def apply_attention_dot( # Apply index mask, deepseek sparse attention # index mask contains 0.0 for kept tokens and large negative for masked tokens. if index_mask is not None: + # index_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len] + index_mask = index_mask[:, None, None, :, :] # attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len] - # index_mask: [b, 1, 1, q_len, kv_len] attn_weights = apply_mask_to_logits(attn_weights, index_mask) if self.is_partition_in_decode(q_seq_len): diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 91ef50bd26..b3af2e4baf 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2205,8 +2205,12 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de if self.use_sparse_indexer: if self.q_lora_rank == 0: raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.") - if self.attention not in ("dot_product"): - raise ValueError("Sparse indexer is only supported dot_product attention") + supports_dot_product = self.attention == "dot_product" + supports_flash_splash = self.attention == "flash" and self.use_tokamax_splash + if not (supports_dot_product or supports_flash_splash): + raise NotImplementedError( + "Sparse indexer is only supported dot_product attention or flash attention with tokamax splash." + ) if self.attention_type == AttentionType.CHUNK.value and ( not isinstance(self.chunk_attn_window_size, int) or self.chunk_attn_window_size <= 0 ): diff --git a/tests/unit/deepseek32_vs_reference_test.py b/tests/unit/deepseek32_vs_reference_test.py index 0cb119d9f1..9ae4f70ce3 100644 --- a/tests/unit/deepseek32_vs_reference_test.py +++ b/tests/unit/deepseek32_vs_reference_test.py @@ -82,6 +82,7 @@ class Config: qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 + use_tokamax_splash: bool = True # yarn rope_type: str = "yarn" original_max_position_embeddings: int = 4096 @@ -98,7 +99,6 @@ class Config: use_sparse_indexer: bool = True index_n_heads: int = 64 index_head_dim: int = 128 # > qk_rope_head_dim - index_topk: int = 4 class ModelArgs: @@ -107,7 +107,7 @@ class ModelArgs: Maps MaxText Config keys to the specific variable names expected by the reference implementation. """ - def __init__(self, config: Config, max_batch_size: int = 8): + def __init__(self, config: Config, max_batch_size: int = 8, index_topk: int = 4): self.max_batch_size = max_batch_size self.scale_fmt = None self.max_seq_len = config.max_position_embeddings @@ -119,6 +119,7 @@ def __init__(self, config: Config, max_batch_size: int = 8): self.qk_nope_head_dim = config.qk_nope_head_dim self.qk_rope_head_dim = config.qk_rope_head_dim self.v_head_dim = config.v_head_dim + self.use_tokamax_splash = config.use_tokamax_splash # yarn self.original_seq_len = config.original_max_position_embeddings self.rope_theta = float(config.rope_max_timescale) @@ -129,7 +130,7 @@ def __init__(self, config: Config, max_batch_size: int = 8): # indexer self.index_n_heads = config.index_n_heads self.index_head_dim = config.index_head_dim - self.index_topk = config.index_topk + self.index_topk = index_topk # ----------------------------------------------------------------------------- @@ -457,14 +458,14 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor: class Indexer(torch.nn.Module): # pylint: disable=missing-class-docstring - def __init__(self, args: ModelArgs): + def __init__(self, args: ModelArgs, index_topk: int = 4): super().__init__() self.dim: int = args.dim self.n_heads: int = args.index_n_heads self.n_local_heads = args.index_n_heads // world_size self.head_dim: int = args.index_head_dim self.rope_head_dim: int = args.qk_rope_head_dim - self.index_topk: int = args.index_topk + self.index_topk: int = index_topk self.q_lora_rank: int = args.q_lora_rank self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim) self.wk = Linear(self.dim, self.head_dim) @@ -580,7 +581,7 @@ class MLA(nn.Module): softmax_scale (float): Scaling factor for softmax in attention computation. """ - def __init__(self, args: ModelArgs): + def __init__(self, args: ModelArgs, index_topk: int): super().__init__() self.dim = args.dim self.n_heads = args.n_heads @@ -605,7 +606,7 @@ def __init__(self, args: ModelArgs): mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 self.softmax_scale = self.softmax_scale * mscale * mscale - self.indexer = Indexer(args) + self.indexer = Indexer(args, index_topk) self.register_buffer( "kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False @@ -750,7 +751,7 @@ def get_jax_mla_weights(pt_mla, cfg): } -def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len): +def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len, attention, index_topk): """Returns MaxText configuration and mesh.""" cfg = pyconfig.initialize( [None, get_test_config_path()], @@ -766,7 +767,8 @@ def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len): per_device_batch_size=batch_size, max_target_length=seq_len, max_prefill_predict_length=seq_len, - attention="dot_product", + attention=attention, + index_topk=index_topk, **asdict(config), ) devices_array = maxtext_utils.create_device_mesh(cfg) @@ -785,7 +787,7 @@ def setUp(self): np.random.seed(42) self.dtype = "float32" - self.batch_size = 2 + self.batch_size = 4 self.start_pos = 0 self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) # jax config @@ -861,6 +863,8 @@ def test_indexer_match(self, seq_len=8): dtype=self.dtype, batch_size=self.batch_size, seq_len=self.seq_len, + attention="dot_product", + index_topk=4, ) # Indexer specific RoPE (interleave=False) @@ -906,17 +910,53 @@ class DeepseekV32MLATest(DeepseekTestBase): """Tests for MLA Attention with Sparse Indexing.""" @parameterized.named_parameters( - {"testcase_name": "seq_len=2 (index_topk=4)", "seq_len": 2}, - {"testcase_name": "seq_len=8 (index_topk=4)", "seq_len": 8}, + { + "testcase_name": "dot_product_s2_k4", + "attention": "dot_product", + "seq_len": 2, + "index_topk": 4, + }, + { + "testcase_name": "dot_product_s8_k4", + "attention": "dot_product", + "seq_len": 8, + "index_topk": 4, + }, + { + "testcase_name": "dot_product_s128_k4", + "attention": "dot_product", + "seq_len": 128, + "index_topk": 4, + "check_norm": True, + }, + { + "testcase_name": "dot_product_s128_k128", + "attention": "dot_product", + "seq_len": 128, + "index_topk": 128, + "check_norm": True, + }, + { + "testcase_name": "flash_s128_k4", + "attention": "flash", + "seq_len": 128, + "index_topk": 4, + "check_norm": True, + }, + { + "testcase_name": "flash_s128_k128", + "attention": "flash", + "seq_len": 128, + "index_topk": 128, + "check_norm": True, + }, ) - # index_topk=4 - def test_mla_match(self, seq_len=8): - """Verifies MLA output (train mode) matches PyTorch (MHA mode) with indexer.""" - + def test_mla_parity(self, attention, seq_len, index_topk, check_norm=False): + """Verifies JAX MLA output against the PyTorch reference implementation.""" torch_inputs, jax_inputs = self.get_data(seq_len) # 1. PyTorch Run - pt_mla = MLA(self.pt_args) + pt_mla = MLA(self.pt_args, index_topk) init_torch_weights(pt_mla) pt_mla.eval() @@ -936,6 +976,8 @@ def test_mla_match(self, seq_len=8): dtype=self.dtype, batch_size=self.batch_size, seq_len=self.seq_len, + attention=attention, + index_topk=index_topk, ) jax_mla = attention_mla.MLA( @@ -959,7 +1001,7 @@ def test_mla_match(self, seq_len=8): rope_factor=cfg.rope_factor, max_target_length=self.seq_len, mesh=mesh, - attention_kernel="dot_product", + attention_kernel=attention, inputs_q_shape=(self.batch_size, self.seq_len, cfg.emb_dim), inputs_kv_shape=(self.batch_size, self.seq_len, cfg.emb_dim), rngs=self.nnx_rng, @@ -976,10 +1018,17 @@ def test_mla_match(self, seq_len=8): model_mode=MODEL_MODE_TRAIN, ) - # 3 Compare - print("torch out", pt_out) - print("jax out", jax_out) - np.testing.assert_allclose(to_jax(pt_out), jax_out, rtol=1e-2, atol=1e-2) + # 3. Compare + if check_norm: + expected = to_jax(pt_out) / jnp.linalg.norm(to_jax(pt_out)) + actual = jax_out / jnp.linalg.norm(jax_out) + else: + expected = to_jax(pt_out) + actual = jax_out + + print("torch out", expected) + print("jax out", actual) + np.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-2) if __name__ == "__main__": diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 1dd31413dc..4ccb678cc7 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -767,7 +767,8 @@ def test_deepseek32(self): "megablox=True", "per_device_batch_size=1", "max_target_length=1024", - "attention=dot_product", # TODO: update to flash attention when it's available. + "attention=flash", + "use_tokamax_splash=True", "dtype=bfloat16", "weight_dtype=bfloat16", # without_device_limit