
Commit 67475a6

[DCP][Bugfix][CI] Fix accuracy issue of DCP when using FLASH_ATTN_MLA (#30309)
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
1 parent 9c32df6

2 files changed: +6 −2 lines

tests/distributed/test_context_parallel.py

Lines changed: 4 additions & 1 deletion
@@ -123,8 +123,11 @@ def iter_params(self, model_id: str):
 
 CP_TEXT_GENERATION_MODELS = {
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
+        CPTestSettings.detailed(dcp_multipliers=[1]),
         CPTestSettings.detailed(
-            dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
+            dcp_multipliers=[0.5],
+            cp_kv_cache_interleave_size=64,
+            attn_backend="FLASHMLA",
         ),
     ],
     "Qwen/Qwen2.5-1.5B-Instruct": [

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 2 additions & 1 deletion
@@ -105,13 +105,14 @@ def __init__(
         vllm_config: VllmConfig,
         device: torch.device,
     ):
+        interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
         super().__init__(
             kv_cache_spec,
             layer_names,
             vllm_config,
             device,
             FlashAttnMLAMetadata,
-            supports_dcp_with_varlen=True,
+            supports_dcp_with_varlen=(interleave_size == 1),
         )
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = get_flash_attn_version() == 3
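The functional fix is the last changed line: the FlashAttention MLA metadata builder now advertises DCP-with-varlen support only when the KV cache is interleaved token by token (interleave_size == 1). A minimal sketch of that gating pattern, using a simplified stand-in class rather than the real vLLM builder:

class _BuilderSketch:
    # Simplified stand-in for FlashAttnMLAMetadataBuilder (illustrative only).
    def __init__(self, interleave_size: int):
        # Per this commit, the varlen DCP path is only enabled for
        # token-level round-robin sharding; for larger interleave sizes
        # the backend would take a different, non-varlen path.
        self.supports_dcp_with_varlen = interleave_size == 1

assert _BuilderSketch(interleave_size=1).supports_dcp_with_varlen
assert not _BuilderSketch(interleave_size=64).supports_dcp_with_varlen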
