From 0ae43dbf8cb28a299ae724fc742b0c5bcddea868 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Wed, 10 Sep 2025 05:19:26 -0400
Subject: [PATCH] [Attention] add DCP support for FLASH_ATTN_MLA backend (#24453)

Signed-off-by: Lucas Wilkinson
Signed-off-by: Matthew Bonanni
Co-authored-by: Matthew Bonanni
---
 .../v1/attention/backends/mla/flashattn_mla.py | 18 ++++++++++++++++--
 vllm/v1/worker/gpu_model_runner.py             |  3 +++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 12f206637d7c..472095e13615 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
@@ -98,6 +99,11 @@ class FlashAttnMLAMetadataBuilder(
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
+        # TODO(lucas): Until we add support for the DCP custom masking, we
+        # need to restrict decodes to q_len == 1 when DCP is enabled.
+        self.__class__.reorder_batch_threshold = 1 \
+            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -172,6 +178,7 @@ class FlashAttnMLAMetadataBuilder(
 
 
 class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
             self,
@@ -239,7 +246,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
         # to prevent invalid grid configuration during graph capture.
         max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
 
-        o = flash_attn_varlen_func(
+        attn_out = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
@@ -251,9 +258,16 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
             block_table=attn_metadata.decode.block_table,
             softmax_scale=self.scale,
             causal=True,
+            return_softmax_lse=self.need_to_return_lse_for_decode,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
         )
 
-        return self._v_up_proj(o)
+        if self.need_to_return_lse_for_decode:
+            o, lse = attn_out
+            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
+            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
+        else:
+            o = attn_out
+            return o, None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 897c3a621320..944793cad94f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -440,6 +440,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return
 
         if self.reorder_batch_threshold is not None:
+            # NOTE(lucas): currently no backend supports the custom masking
+            # required for DCP with q_len > 1, so we assert here. Remove this
+            # assert once custom mask support is added to FA3.
             if self.dcp_world_size > 1:
                 assert self.reorder_batch_threshold == 1, \
                     "DCP does not support reorder_batch_threshold > 1 yet."
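
Context for the LSE change above (not part of the patch): under decode context parallel (DCP), each rank attends over only its shard of the KV cache, so the per-rank partial outputs must be merged with a log-sum-exp correction; that is why _forward_decode now returns the LSE in [B, H] layout alongside the output. The Python sketch below is only an illustration of that merge under assumed tensor shapes; the function name merge_partial_attention and the toy attention helper are hypothetical and do not correspond to vLLM's actual DCP reduction code.

import torch


def merge_partial_attention(outs: list[torch.Tensor],
                            lses: list[torch.Tensor]) -> torch.Tensor:
    """Merge per-rank partial attention outputs using their LSEs (sketch).

    outs[r]: [B, H, D] partial attention output from rank r
    lses[r]: [B, H]    log-sum-exp of rank r's softmax scores
    """
    lse = torch.stack(lses)                          # [R, B, H]
    lse_total = torch.logsumexp(lse, dim=0)          # [B, H]
    weights = torch.exp(lse - lse_total)             # [R, B, H]
    out = torch.stack(outs)                          # [R, B, H, D]
    return (weights.unsqueeze(-1) * out).sum(dim=0)  # [B, H, D]


if __name__ == "__main__":
    # Tiny numerical check: merging two KV shards reproduces full attention.
    torch.manual_seed(0)
    B, H, D, S = 2, 4, 8, 16
    q = torch.randn(B, H, D)
    k = torch.randn(B, S, D)
    v = torch.randn(B, S, D)

    def attn(q, k, v):
        # Unscaled dot-product attention plus the LSE of its scores.
        scores = torch.einsum("bhd,bsd->bhs", q, k)
        out = torch.einsum("bhs,bsd->bhd", scores.softmax(-1), v)
        return out, torch.logsumexp(scores, dim=-1)  # [B, H]

    full, _ = attn(q, k, v)
    o0, l0 = attn(q, k[:, :S // 2], v[:, :S // 2])
    o1, l1 = attn(q, k[:, S // 2:], v[:, S // 2:])
    merged = merge_partial_attention([o0, o1], [l0, l1])
    assert torch.allclose(merged, full, atol=1e-5)

The [B, H] layout assumed by this sketch matches what the backend produces via lse.transpose(0, 1), since FlashAttention returns the LSE as [H, B].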