[Attention] add DCP support for FLASH_ATTN_MLA backend (#24453)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-09-10 05:19:26 -04:00 committed by GitHub
parent 267c80d31f
commit 0ae43dbf8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 2 deletions

View File

@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
@ -98,6 +99,11 @@ class FlashAttnMLAMetadataBuilder(
# pre-allocated during capture.
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.__class__.reorder_batch_threshold = 1 \
if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.fa_aot_schedule:
@ -172,6 +178,7 @@ class FlashAttnMLAMetadataBuilder(
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
@ -239,7 +246,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
# to prevent invalid grid configuration during graph capture.
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
o = flash_attn_varlen_func(
attn_out = flash_attn_varlen_func(
q=q_pe,
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
@ -251,9 +258,16 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
block_table=attn_metadata.decode.block_table,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=self.need_to_return_lse_for_decode,
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
)
return self._v_up_proj(o)
if self.need_to_return_lse_for_decode:
o, lse = attn_out
# FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
else:
o = attn_out
return o, None

View File

@ -440,6 +440,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
if self.reorder_batch_threshold is not None:
# NOTE(lucas): currently no backend supports the custom masking
# required for DCP with q_len > 1, so we assert here. Remove this
# assert once the custom mask is support is added to FA3.
if self.dcp_world_size > 1:
assert self.reorder_batch_threshold == 1, \
"DCP not support reorder_batch_threshold > 1 now."