[Attention][DCP] Support DCP with query length > 1 (MTP) with FA3 (#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>
This commit is contained in:
Ming Yang 2025-10-09 08:06:29 -07:00 committed by GitHub
parent 2c1c7dfb35
commit 3b736e1c38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 45 additions and 13 deletions

View File

@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00
GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@ -370,6 +370,7 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
class MLACommonDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
dcp_tot_seq_lens: Optional[torch.Tensor]
D = TypeVar("D", bound=MLACommonDecodeMetadata)
@ -682,10 +683,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
def build_for_cudagraph_capture(
@ -727,6 +730,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@ -742,7 +746,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + (
assert dcp_local_seq_lens is not None
dcp_local_seq_lens[:num_decodes] = seq_lens[
:num_decodes
] // self.dcp_world_size + (
self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
)
@ -899,10 +906,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
seq_lens_device=dcp_local_seq_lens[:num_decodes]
if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
else seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
query_start_loc_device=query_start_loc[: num_decodes + 1],
num_decode_tokens=num_decode_tokens,
dcp_tot_seq_lens_device=seq_lens[:num_decodes]
if self.dcp_world_size > 1
else None,
)
attn_metadata = self.metadata_cls(

View File

@ -17,7 +17,6 @@ from vllm.attention.utils.fa_utils import (
get_flash_attn_version,
)
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
@ -107,12 +106,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self.reorder_batch_threshold = (
1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
)
def _schedule_decode(
self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
@ -121,7 +114,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
num_heads_q=self.num_heads,
num_heads_q=self.num_heads * self.dcp_world_size,
num_heads_kv=1,
headdim=self.mla_dims.qk_rope_head_dim,
cache_seqlens=seqlens,
@ -142,10 +135,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
max_seq_len = seq_lens_device.max().item()
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
@ -188,6 +182,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
@ -289,6 +284,9 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
cp_world_size=self.dcp_world_size,
cp_rank=self.dcp_rank,
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
)
if self.need_to_return_lse_for_decode:

View File

@ -106,6 +106,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = get_mla_metadata(
seq_lens_device,
@ -146,6 +147,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)

View File

@ -116,6 +116,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
dcp_tot_seq_lens_device: Optional[torch.Tensor],
) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
@ -174,6 +175,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
return attn_metadata

View File

@ -93,6 +93,9 @@ class CommonAttentionMetadata:
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None
dcp_local_seq_lens: Optional[torch.Tensor] = None
"""Sequence lengths of the local rank in decode context parallelism world"""
def slice_query_start_locs(
query_start_loc: torch.Tensor,

View File

@ -597,6 +597,7 @@ class EagleProposer:
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)
token_indices_to_sample = (
@ -868,6 +869,7 @@ class EagleProposer:
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
)
return spec_common_attn_metadata, token_indices

View File

@ -398,6 +398,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
# Because inputs_embeds may be bfloat16 and we don't need a numpy
# version of this tensor, avoid a RuntimeError by not creating a
# numpy buffer.
@ -581,7 +585,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE(lucas): currently no backend supports the custom masking
# required for DCP with q_len > 1, so we assert here. Remove this
# assert once the custom mask is support is added to FA3.
if self.dcp_world_size > 1:
if (
self.dcp_world_size > 1
and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
):
assert self.reorder_batch_threshold == 1, (
"DCP not support reorder_batch_threshold > 1 now."
)
@ -1335,6 +1342,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_logits_indices=logits_indices.size(0),
causal=True,
encoder_seq_lens=encoder_seq_lens,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)
if self.speculative_config and spec_decode_common_attn_metadata is None:
@ -3310,6 +3320,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_group_id
].slot_mapping.gpu[:num_tokens],
causal=True,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
if ubatch_slices is not None: