[Attention] add DCP support for FLASH_ATTN_MLA backend (#24453)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
parent 267c80d31f
commit 0ae43dbf8c
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
@@ -98,6 +99,11 @@ class FlashAttnMLAMetadataBuilder(
             # pre-allocated during capture.
             self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
+        # TODO(lucas): Until we add support for the DCP custom masking we need
+        # to restrict decodes to q_len == 1 when DCP is enabled.
+        self.__class__.reorder_batch_threshold = 1 \
+            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
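
The hunk above carries the key DCP restriction: when the decode-context-parallel (DCP) group spans more than one rank, the builder forces reorder_batch_threshold to 1 so that decode requests keep q_len == 1. A minimal standalone sketch of that gating, using an illustrative helper name that is not part of vLLM:

# Hypothetical helper mirroring the gating in the hunk above; not vLLM API.
def cap_reorder_batch_threshold(class_default: int, dcp_world_size: int) -> int:
    # Until FA3 gains the custom mask needed for DCP with q_len > 1,
    # decodes must stay at a single query token per request.
    return 1 if dcp_world_size > 1 else class_default

assert cap_reorder_batch_threshold(class_default=4, dcp_world_size=2) == 1
assert cap_reorder_batch_threshold(class_default=4, dcp_world_size=1) == 4
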
@@ -172,6 +178,7 @@ class FlashAttnMLAMetadataBuilder(
 
 
 class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
         self,
@@ -239,7 +246,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
         # to prevent invalid grid configuration during graph capture.
         max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
 
-        o = flash_attn_varlen_func(
+        attn_out = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
@@ -251,9 +258,16 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
             block_table=attn_metadata.decode.block_table,
             softmax_scale=self.scale,
             causal=True,
+            return_softmax_lse=self.need_to_return_lse_for_decode,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
         )
 
-        return self._v_up_proj(o)
+        if self.need_to_return_lse_for_decode:
+            o, lse = attn_out
+            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
+            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
+        else:
+            o = attn_out
+        return o, None
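
The decode path above now returns the attention output together with its log-sum-exp (LSE) so DCP can combine partial results across ranks; FlashAttention hands the LSE back as [num_heads, num_tokens], while the DCP side expects [num_tokens, num_heads]. A minimal sketch of that layout fix, assuming PyTorch and made-up sizes:

# Illustrative only: the [H, B] -> [B, H] transpose applied to the LSE tensor.
import torch

num_heads, num_tokens = 16, 8
lse = torch.randn(num_heads, num_tokens)   # [H, B], as FlashAttention returns it
lse_for_dcp = lse.transpose(0, 1)          # [B, H], as the DCP combine step expects
assert lse_for_dcp.shape == (num_tokens, num_heads)
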
@@ -440,6 +440,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return
 
         if self.reorder_batch_threshold is not None:
+            # NOTE(lucas): currently no backend supports the custom masking
+            # required for DCP with q_len > 1, so we assert here. Remove this
+            # assert once custom mask support is added to FA3.
             if self.dcp_world_size > 1:
                 assert self.reorder_batch_threshold == 1, \
                     "DCP does not support reorder_batch_threshold > 1 yet."
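
The runner-side hunk above documents the same limitation from the scheduler's perspective: with DCP active, a reorder_batch_threshold above 1 must fail fast because no backend implements the custom mask needed for multi-token DCP decodes yet. A standalone sketch of that guard, with a hypothetical function name that is not the GPUModelRunner API:

# Hypothetical check mirroring the assert in the hunk above; not vLLM API.
from typing import Optional

def check_dcp_reorder_threshold(reorder_batch_threshold: Optional[int],
                                dcp_world_size: int) -> None:
    if reorder_batch_threshold is not None and dcp_world_size > 1:
        assert reorder_batch_threshold == 1, \
            "DCP does not support reorder_batch_threshold > 1 yet."

check_dcp_reorder_threshold(reorder_batch_threshold=1, dcp_world_size=2)     # passes
check_dcp_reorder_threshold(reorder_batch_threshold=None, dcp_world_size=4)  # passes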