diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
index 3cb533dccd62c..aa47f28a34dd5 100644
--- a/tests/distributed/test_context_parallel.py
+++ b/tests/distributed/test_context_parallel.py
@@ -123,8 +123,11 @@ class CPTestSettings:
 
 CP_TEXT_GENERATION_MODELS = {
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
+        CPTestSettings.detailed(dcp_multipliers=[1]),
         CPTestSettings.detailed(
-            dcp_multipliers=[0.5, 1], cp_kv_cache_interleave_size=64
+            dcp_multipliers=[0.5],
+            cp_kv_cache_interleave_size=64,
+            attn_backend="FLASHMLA",
         ),
     ],
     "Qwen/Qwen2.5-1.5B-Instruct": [
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index eccf4ec791095..b28814aceada9 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -105,13 +105,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         vllm_config: VllmConfig,
         device: torch.device,
     ):
+        interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
         super().__init__(
             kv_cache_spec,
             layer_names,
             vllm_config,
             device,
             FlashAttnMLAMetadata,
-            supports_dcp_with_varlen=True,
+            supports_dcp_with_varlen=(interleave_size == 1),
         )
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = get_flash_attn_version() == 3
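
For context, a minimal standalone sketch of the gating this change introduces. The `ParallelConfig` dataclass and the free `supports_dcp_with_varlen` function below are simplified stand-ins for illustration only; in vLLM the value comes from `vllm_config.parallel_config` and the flag is passed to `MLACommonMetadataBuilder.__init__`:

```python
# Simplified sketch (hypothetical names, not vLLM's actual classes): the
# FlashAttn MLA metadata builder now advertises varlen-DCP support only
# when the KV cache is interleaved across CP ranks one token at a time.
from dataclasses import dataclass


@dataclass
class ParallelConfig:
    # Stand-in for vllm_config.parallel_config
    cp_kv_cache_interleave_size: int = 1


def supports_dcp_with_varlen(parallel_config: ParallelConfig) -> bool:
    # Mirrors the predicate passed to super().__init__() in the diff above.
    return parallel_config.cp_kv_cache_interleave_size == 1


# Interleave size 1: varlen DCP stays enabled for FlashAttn MLA.
assert supports_dcp_with_varlen(ParallelConfig(cp_kv_cache_interleave_size=1))
# Interleave size 64: varlen DCP is disabled.
assert not supports_dcp_with_varlen(ParallelConfig(cp_kv_cache_interleave_size=64))
```

This matches the test change above: the `cp_kv_cache_interleave_size=64` case is now pinned to the FLASHMLA backend, while the plain DCP case (interleave size 1) continues to run on FlashAttn MLA.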