mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 01:54:28 +08:00

[Attention][DCP] Support DCP with query length > 1 (MTP) with FA3 (#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>

parent 2c1c7dfb35
commit 3b736e1c38
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
           vllm-flash-attn
           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-          GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00
+          GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1
           GIT_PROGRESS TRUE
           # Don't share the vllm-flash-attn build between build types
           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -370,6 +370,7 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    dcp_tot_seq_lens: Optional[torch.Tensor]


 D = TypeVar("D", bound=MLACommonDecodeMetadata)
@@ -682,10 +683,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> MLACommonDecodeMetadata:
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )

     def build_for_cudagraph_capture(
@@ -727,6 +730,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens

         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

@@ -742,7 +746,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

         # Note(hc): update seq_lens of decode reqs under DCP.
         if self.dcp_world_size > 1:
-            seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + (
+            assert dcp_local_seq_lens is not None
+            dcp_local_seq_lens[:num_decodes] = seq_lens[
+                :num_decodes
+            ] // self.dcp_world_size + (
                 self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
             )

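The hunk above changes where the result lands (a dedicated dcp_local_seq_lens buffer instead of overwriting seq_lens in place) but keeps the arithmetic: it computes how many KV entries each rank holds when token positions are dealt round-robin across the DCP group. Leaving seq_lens untouched matters because the unsharded values are now also needed downstream as dcp_tot_seq_lens. A self-contained check of that arithmetic, assuming the interleaved placement the formula implies (helper name is mine):

    import torch

    def local_kv_len(tot_seq_lens: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # Token position p is stored on rank p % world_size, so ranks up to
        # the owner of the latest token hold one entry beyond the floor share.
        return tot_seq_lens // world_size + (rank <= (tot_seq_lens - 1) % world_size)

    tot = torch.tensor([10, 11, 13])
    world_size = 4
    shards = torch.stack([local_kv_len(tot, r, world_size) for r in range(world_size)])
    assert torch.equal(shards.sum(dim=0), tot)              # shards cover each sequence
    assert torch.equal(shards[0], torch.tensor([3, 3, 4]))  # rank 0's local lengths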
@@ -899,10 +906,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:num_decodes, ...],
                seq_lens_cpu=seq_lens_cpu[:num_decodes],
-               seq_lens_device=seq_lens[:num_decodes],
+               seq_lens_device=dcp_local_seq_lens[:num_decodes]
+               if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
+               else seq_lens[:num_decodes],
                query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
                query_start_loc_device=query_start_loc[: num_decodes + 1],
                num_decode_tokens=num_decode_tokens,
+               dcp_tot_seq_lens_device=seq_lens[:num_decodes]
+               if self.dcp_world_size > 1
+               else None,
            )

        attn_metadata = self.metadata_cls(
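Under DCP the builder therefore forwards two views of the same lengths: the rank-local shard sizes drive the decode kernel as seq_lens_device, while the unsharded lengths ride along as dcp_tot_seq_lens_device for the new FA3 masking path. A condensed sketch of that selection, as a hypothetical standalone function:

    from typing import Optional

    import torch

    def select_decode_lens(
        seq_lens: torch.Tensor,
        dcp_local_seq_lens: Optional[torch.Tensor],
        dcp_world_size: int,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return (lengths the decode kernel sees, total lengths for DCP masking)."""
        if dcp_world_size > 1 and dcp_local_seq_lens is not None:
            return dcp_local_seq_lens, seq_lens
        return seq_lens, None

    lens, local = torch.tensor([10, 11]), torch.tensor([3, 3])
    kernel_lens, tot_lens = select_decode_lens(lens, local, dcp_world_size=4)
    assert kernel_lens is local and tot_lens is lens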
@@ -17,7 +17,6 @@ from vllm.attention.utils.fa_utils import (
     get_flash_attn_version,
 )
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -107,12 +106,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         # pre-allocated during capture.
         self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

-        # TODO(lucas): Until we add support for the DCP custom masking we need
-        # to restrict decodes to q_len == 1 when DCP is enabled.
-        self.reorder_batch_threshold = (
-            1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
-        )
-
     def _schedule_decode(
         self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
     ):
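The deleted clamp existed because reorder_batch_threshold is, as I read the builder code, the largest query length a request may have while still being routed through the decode path; the old TODO pinned it to 1 under DCP, which excluded MTP-style decodes that carry several speculative tokens per step. A toy illustration of the classification (helper is mine, not vLLM's):

    import torch

    def count_decode_reqs(query_lens: torch.Tensor, threshold: int) -> int:
        # Requests with q_len <= threshold take the decode kernel; the rest
        # are handled as prefills/extends.
        return int((query_lens <= threshold).sum())

    query_lens = torch.tensor([1, 1, 3, 3, 40])  # MTP decodes carry a few tokens
    assert count_decode_reqs(query_lens, threshold=1) == 2  # old DCP restriction
    assert count_decode_reqs(query_lens, threshold=3) == 4  # MTP-sized decodes allowed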
@@ -121,7 +114,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             batch_size=num_reqs,
             max_seqlen_q=max_query_len,
             max_seqlen_k=max_seq_len,
-            num_heads_q=self.num_heads,
+            num_heads_q=self.num_heads * self.dcp_world_size,
             num_heads_kv=1,
             headdim=self.mla_dims.qk_rope_head_dim,
             cache_seqlens=seqlens,
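The query-head count handed to the FA3 scheduler grows by a factor of dcp_world_size, which matches the usual DCP decode recipe: MLA decode has a single latent KV head, and the queries gathered from all DCP ranks stack along the head axis before attending to the rank-local KV shard. A shape-only illustration (sizes arbitrary; torch.cat stands in for the collective):

    import torch

    num_local_heads, dcp_world_size, head_dim = 16, 4, 64
    q_local = torch.randn(8, num_local_heads, head_dim)  # [tokens, heads, dim]

    # Stand-in for the all-gather over the DCP group along the head axis.
    q_gathered = torch.cat([q_local] * dcp_world_size, dim=1)
    assert q_gathered.shape == (8, num_local_heads * dcp_world_size, head_dim)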
@@ -142,10 +135,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         max_query_len = query_lens_cpu.max().item()
-        max_seq_len = seq_lens_cpu.max().item()
+        max_seq_len = seq_lens_device.max().item()

         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
@@ -188,6 +182,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             max_seq_len=max_seq_len,
             scheduler_metadata=scheduler_metadata,
             max_num_splits=max_num_splits,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )


@@ -289,6 +284,9 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
+            cp_world_size=self.dcp_world_size,
+            cp_rank=self.dcp_rank,
+            cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
         )

         if self.need_to_return_lse_for_decode:
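The three cp_* arguments give the FA3 kernel what it needs to mask correctly when q_len > 1: with KV interleaved across ranks, which local slots a query position may attend to depends on the rank and on the total (unsharded) sequence length. Each rank still only produces a partial result over its shard, which is why the decode path returns a log-sum-exp (need_to_return_lse_for_decode): partials over disjoint KV shards combine exactly by LSE weighting. A self-contained demonstration of that merge, assuming round-robin sharding (single query vector for brevity):

    import torch

    torch.manual_seed(0)
    seq_len, dim, world = 12, 8, 4
    q = torch.randn(dim)
    k, v = torch.randn(seq_len, dim), torch.randn(seq_len, dim)

    def partial_attn(q, k, v):
        # Attention over one KV shard, returning output and log-sum-exp.
        scores = k @ q / dim**0.5
        return torch.softmax(scores, dim=0) @ v, torch.logsumexp(scores, dim=0)

    ref, _ = partial_attn(q, k, v)  # reference: attend to the full sequence

    # DCP: shard KV round-robin by position, attend locally, merge by LSE.
    outs, lses = zip(*(partial_attn(q, k[r::world], v[r::world]) for r in range(world)))
    weights = torch.softmax(torch.stack(lses), dim=0)           # [world]
    merged = (torch.stack(outs) * weights[:, None]).sum(dim=0)  # [dim]
    assert torch.allclose(merged, ref, atol=1e-5)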
@@ -106,6 +106,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = get_mla_metadata(
             seq_lens_device,
@@ -146,6 +147,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             seq_lens=seq_lens_device,
             tile_scheduler_metadata=tile_scheduler_metadata,
             num_splits=num_splits,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )


@@ -116,6 +116,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
         query_start_loc_cpu: torch.Tensor,
         query_start_loc_device: torch.Tensor,
         num_decode_tokens: int,
+        dcp_tot_seq_lens_device: Optional[torch.Tensor],
     ) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens_device + page_size - 1) // page_size
@@ -174,6 +175,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr,
+            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )

         return attn_metadata
@@ -93,6 +93,9 @@ class CommonAttentionMetadata:
     # Needed by CrossAttentionBuilder
     encoder_seq_lens: Optional[np.ndarray] = None

+    dcp_local_seq_lens: Optional[torch.Tensor] = None
+    """Sequence lengths of the local rank in decode context parallelism world"""
+

 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
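CommonAttentionMetadata is the per-step struct handed to every attention backend's metadata builder, so the new field defaults to None and non-DCP construction sites need no changes. A reduced two-field sketch of the pattern:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class CommonAttentionMetadataSketch:
        # Unsharded KV length per request (what non-DCP backends consume).
        seq_lens: torch.Tensor
        # Rank-local KV lengths under decode context parallelism; None
        # whenever DCP is disabled.
        dcp_local_seq_lens: Optional[torch.Tensor] = None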
@@ -597,6 +597,7 @@ class EagleProposer:
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[token_indices],
             causal=True,
+            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
         )

         token_indices_to_sample = (
@@ -868,6 +869,7 @@ class EagleProposer:
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[token_indices],
             causal=True,
+            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
         )

         return spec_common_attn_metadata, token_indices
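Both EagleProposer call sites rebuild CommonAttentionMetadata field by field when re-slicing the batch for draft steps, so an optional field that is not forwarded explicitly silently resets to its None default; that is the failure mode these two one-line hunks close. A toy illustration of the hazard:

    from dataclasses import dataclass, replace
    from typing import Optional

    @dataclass
    class Meta:
        causal: bool
        dcp_local_seq_lens: Optional[list] = None

    base = Meta(causal=True, dcp_local_seq_lens=[3, 4])
    rebuilt = Meta(causal=True)              # field forgotten: silently None
    forwarded = replace(base, causal=True)   # dataclasses.replace keeps it
    assert rebuilt.dcp_local_seq_lens is None
    assert forwarded.dcp_local_seq_lens == [3, 4]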
@@ -398,6 +398,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.max_num_reqs + 1, dtype=torch.int32
         )
         self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+        if self.dcp_world_size > 1:
+            self.dcp_local_seq_lens = self._make_buffer(
+                self.max_num_reqs, dtype=torch.int32
+            )
         # Because inputs_embeds may be bfloat16 and we don't need a numpy
         # version of this tensor, avoid a RuntimeError by not creating a
         # numpy buffer.
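_make_buffer is the runner's helper for persistent tensors sized for the worst case up front; later steps slice a live prefix (self.dcp_local_seq_lens.gpu[:num_reqs] in the hunks below) instead of allocating, which also keeps addresses stable for CUDA graph capture. A minimal sketch of that pattern, with a hypothetical stand-in for the runner's buffer type:

    import torch

    class PerStepBuffer:
        """Allocate once at the max batch size; slice a live view each step."""

        def __init__(self, max_num_reqs: int, dtype: torch.dtype, device: str = "cpu"):
            # The real runner pairs a (pinned) CPU staging tensor with a GPU tensor.
            self.cpu = torch.zeros(max_num_reqs, dtype=dtype)
            self.gpu = torch.zeros(max_num_reqs, dtype=dtype, device=device)

        def copy_to_gpu(self, num_reqs: int) -> torch.Tensor:
            self.gpu[:num_reqs].copy_(self.cpu[:num_reqs], non_blocking=True)
            return self.gpu[:num_reqs]  # stable storage, fresh view

    buf = PerStepBuffer(max_num_reqs=1024, dtype=torch.int32)
    buf.cpu[:3] = torch.tensor([3, 3, 4], dtype=torch.int32)
    local_lens = buf.copy_to_gpu(num_reqs=3)  # cf. dcp_local_seq_lens.gpu[:num_reqs]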
@@ -581,7 +585,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # NOTE(lucas): currently no backend supports the custom masking
         # required for DCP with q_len > 1, so we assert here. Remove this
         # assert once custom mask support is added to FA3.
-        if self.dcp_world_size > 1:
+        if (
+            self.dcp_world_size > 1
+            and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
+        ):
             assert self.reorder_batch_threshold == 1, (
                 "DCP does not support reorder_batch_threshold > 1 yet."
             )
@@ -1335,6 +1342,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_logits_indices=logits_indices.size(0),
             causal=True,
             encoder_seq_lens=encoder_seq_lens,
+            dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+            if self.dcp_world_size > 1
+            else None,
         )

         if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -3310,6 +3320,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 kv_cache_group_id
             ].slot_mapping.gpu[:num_tokens],
             causal=True,
+            dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+            if self.dcp_world_size > 1
+            else None,
         )
         for attn_group in self.attn_groups[kv_cache_group_id]:
             if ubatch_slices is not None: