[DCP] Support Decode Context Parallel (DCP) for GQA with FlashAttention (#24864)
Signed-off-by: yuanyongjie.yyj <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <32334296+FENP@users.noreply.github.com>
Signed-off-by: Jaya Yuan <yuanyongjie.yyj@antgroup.com>
parent fdd32750f0
commit ea97940d6c
@@ -204,17 +204,21 @@ def _compare_cp_with_tp(

CP_TEXT_GENERATION_MODELS = {
    # [MLA attention only]
    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
        CPTestSettings.detailed(),
        CPTestSettings.detailed(tp_base=2),
    ],
    "bigcode/gpt_bigcode-santacoder": [
        CPTestSettings.detailed(),
        CPTestSettings.detailed(tp_base=2),
    ],
}

CP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "bigcode/gpt_bigcode-santacoder",
]

@@ -262,7 +262,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {

    "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
    "GPTBigCodeForCausalLM": _HfExamplesInfo(
        "bigcode/starcoder",
        extras={"tiny": "bigcode/tiny_starcoder_py"},
        extras={
            "tiny": "bigcode/tiny_starcoder_py",
            "santacoder": "bigcode/gpt_bigcode-santacoder",
        },
        min_transformers_version="4.55.1",
        transformers_version_reason="HF model broken in 4.55.0",
    ),

@@ -173,6 +173,7 @@ def cp_lse_ag_out_rs(
    cp_attn_lse: torch.Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext = None,
    return_lse=False,
):
    """
    cp_attn_out: [ B, H, D ]

@@ -192,8 +193,15 @@ def cp_lse_ag_out_rs(

    cp_attn_lse = cp_attn_lse.contiguous()
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    assert out.is_contiguous()
    out = cp_group.reduce_scatter(out, dim=1)

    if return_lse:
        cp_num_heads = lse.shape[1] // cp_group.world_size
        cp_rank = cp_group.rank_in_group
        lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
        return out, lse
    return out

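For reference, the LSE correction applied above follows the standard rule for combining softmax-attention partials computed over disjoint KV shards. A minimal two-shard sketch of that rule (illustrative helper only, not the Triton kernel behind correct_attn_out):

import torch

def merge_two_partials(out1, lse1, out2, lse2):
    # out*: [B, H, D] partial attention outputs; lse*: [B, H] their log-sum-exps.
    lse = torch.logaddexp(lse1, lse2)            # normalizer of the combined KV set
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # shard 1's share of the softmax mass
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)     # shard 2's share of the softmax mass
    return w1 * out1 + w2 * out2, lse
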
@@ -1202,6 +1202,23 @@ class ModelConfig:
                "Supported models implement the `SupportsPP` interface."
            )

        decode_context_parallel_size = parallel_config.decode_context_parallel_size
        if decode_context_parallel_size > 1 and not self.use_mla:
            total_num_kv_heads = self.get_total_num_kv_heads()
            assert tensor_parallel_size > total_num_kv_heads, (
                f"tensor parallel size {tensor_parallel_size} must be greater "
                f"than total num kv heads {total_num_kv_heads} when enabling "
                f"decode context parallel for GQA/MQA"
            )

            max_dcp_size = tensor_parallel_size // total_num_kv_heads
            assert decode_context_parallel_size <= max_dcp_size, (
                f"decode context parallel size must be less than or equal to "
                f"(tensor parallel size {tensor_parallel_size} // total "
                f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
                f"but got {decode_context_parallel_size}"
            )

    def get_sliding_window(self) -> int | None:
        """Get the sliding window size from the HF text config if present."""
        return getattr(self.hf_text_config, "sliding_window", None)

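The cap mirrors the KV-head replication factor under tensor parallelism: with GQA/MQA, once TP exceeds the KV-head count each KV head is held by tp // kv_heads ranks, and DCP can shard the KV cache across at most that many of them. A standalone illustration of the arithmetic (not the ModelConfig code path):

def max_dcp_size(tensor_parallel_size: int, total_num_kv_heads: int) -> int:
    # Replication factor of each KV head when TP exceeds the KV-head count.
    assert tensor_parallel_size > total_num_kv_heads
    return tensor_parallel_size // total_num_kv_heads

print(max_dcp_size(8, 1))   # MQA model (1 KV head) at tp=8: dcp up to 8
print(max_dcp_size(8, 4))   # 4 KV heads at tp=8:            dcp up to 2

This is the constraint checked when a non-MLA (GQA/MQA) model is launched with a decode context parallel size greater than 1, per parallel_config.decode_context_parallel_size.
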
@@ -17,6 +17,7 @@ from vllm.attention.backends.abstract import (
    is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (
    flash_attn_supports_fp8,

@@ -32,6 +33,7 @@ if is_flash_attn_varlen_func_available():
    )

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (

@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # For GQA DCP
    max_dcp_context_kv_len: int | None = None
    dcp_context_kv_lens: torch.Tensor | None = None

    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None

@@ -216,6 +222,16 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.aot_schedule = get_flash_attn_version() == 3

        try:
            from vllm.distributed.parallel_state import get_dcp_group

            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )

@@ -306,7 +322,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                batch_size=batch_size,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads_q,
                num_heads_q=self.num_heads_q * self.dcp_world_size,
                num_heads_kv=self.num_heads_kv,
                headdim=self.headdim,
                cache_seqlens=seqlens,

@@ -320,8 +336,35 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                return None

        use_cascade = common_prefix_len > 0
        max_dcp_context_kv_len = 0
        dcp_context_kv_lens = None

        if use_cascade:
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None

        if self.dcp_world_size > 1:
            query_kv_lens_cpu = (
                common_attn_metadata.query_start_loc_cpu[1:]
                - common_attn_metadata.query_start_loc_cpu[:-1]
            )
            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
            )
            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()

            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=dcp_context_kv_lens,
                max_seq_len=max_dcp_context_kv_len,
                causal=False,
            )
        elif use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )

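For each request, dcp_context_kv_lens is this rank's share of the context KV already in the cache: roughly context_len / dcp_world_size, with lower ranks absorbing the remainder. Evaluating the same expression standalone with made-up lengths (illustrative only):

def dcp_local_context_len(context_len: int, world_size: int, rank: int) -> int:
    # Mirrors: len // world_size + (rank <= (len - 1) % world_size)
    return context_len // world_size + int(rank <= (context_len - 1) % world_size)

for length in (10, 11):
    print(length, [dcp_local_context_len(length, 4, r) for r in range(4)])
# 10 [3, 3, 2, 2]
# 11 [3, 3, 3, 2]
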
@@ -348,10 +391,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                causal=True,
            )
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
        scheduler_metadata = schedule(
            batch_size=num_reqs,
            cu_query_lens=query_start_loc,

@@ -379,6 +418,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,

@@ -396,6 +437,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad


class FlashAttentionImpl(AttentionImpl):
    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,

@@ -562,30 +605,45 @@ class FlashAttentionImpl(AttentionImpl):

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                num_splits=attn_metadata.max_num_splits,
                s_aux=self.sinks,
            )
            return output
            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                )
                return output
            else:
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
                return output

        # Cascade attention (rare case).
        cascade_attention(

@@ -615,6 +673,86 @@ class FlashAttentionImpl(AttentionImpl):
        )
        return output

    def _forward_with_dcp(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        q_descale: torch.Tensor | None = None,
        k_descale: torch.Tensor | None = None,
        v_descale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        cu_seqlens_q = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        block_table = attn_metadata.block_table

        query = query.contiguous()
        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
        context_attn_out, context_lse = flash_attn_varlen_func(
            q=query_across_dcp,
            k=key_cache,
            v=value_cache,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=attn_metadata.dcp_context_kv_lens,
            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
            softmax_scale=self.scale,
            causal=False,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            scheduler_metadata=attn_metadata.scheduler_metadata,
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )
        # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
            context_attn_out,
            context_lse.transpose(0, 1),
            get_dcp_group(),
            return_lse=True,
        )
        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()

        query_attn_out, query_lse = flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_k=max_seqlen_q,
            softmax_scale=self.scale,
            causal=attn_metadata.causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )
        assert context_attn_out_cor.shape == query_attn_out.shape
        assert context_lse_cor.shape == query_lse.shape
        merge_attn_states(
            output,
            context_attn_out_cor,
            context_lse_cor,
            query_attn_out,
            query_lse,
        )

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,

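Putting the pieces together: the DCP decode path runs a non-causal pass of the (all-gathered) queries over each rank's local context-KV shard, a local pass over the step's new keys/values, and then combines the partials by LSE weighting (cp_lse_ag_out_rs followed by merge_attn_states). The following single-process simulation checks that this sharded scheme matches ordinary attention; the shards are faked by slicing, the TP/head split is ignored, and all names are illustrative:

import torch

def partial_attn(q, k, v, scale):
    # q: [H, D]; k, v: [T, H, D] -> out: [H, D], lse: [H]
    logits = torch.einsum("hd,thd->ht", q, k) * scale
    lse = torch.logsumexp(logits, dim=-1)
    out = torch.einsum("ht,thd->hd", logits.softmax(-1), v)
    return out, lse

torch.manual_seed(0)
H, D, ctx_len, ws = 4, 8, 13, 4
scale = D ** -0.5
q = torch.randn(H, D)                      # this step's single query token
ctx_k = torch.randn(ctx_len, H, D)         # context KV already in the cache
ctx_v = torch.randn(ctx_len, H, D)
new_k, new_v = torch.randn(1, H, D), torch.randn(1, H, D)  # this step's KV

# Reference: ordinary attention over the full sequence.
ref, _ = partial_attn(q, torch.cat([ctx_k, new_k]), torch.cat([ctx_v, new_v]), scale)

# "DCP": each rank attends over its interleaved context shard, plus one local
# pass over the new token; the partials are merged by their LSE weights.
outs, lses = [], []
for rank in range(ws):
    o, l = partial_attn(q, ctx_k[rank::ws], ctx_v[rank::ws], scale)
    outs.append(o)
    lses.append(l)
o, l = partial_attn(q, new_k, new_v, scale)   # local "query pass"
outs.append(o)
lses.append(l)

lse = torch.stack(lses).logsumexp(dim=0)      # combined normalizer, [H]
merged = sum(torch.exp(li - lse).unsqueeze(-1) * oi for oi, li in zip(outs, lses))
assert torch.allclose(ref, merged, atol=1e-5)
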
@@ -684,6 +822,7 @@ def use_cascade_attention(
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
    dcp_world_size: int,
) -> bool:
    """Decide whether to use cascade attention.

@@ -705,6 +844,9 @@
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False
    # disable cascade attention for DCP
    if dcp_world_size > 1:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention

@@ -345,6 +345,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
        dcp_world_size: int,
    ) -> bool:
        return False

@@ -1523,6 +1523,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            use_sliding_window=use_sliding_window,
            use_local_attention=use_local_attention,
            num_sms=self.num_sms,
            dcp_world_size=self.dcp_world_size,
        )
        return common_prefix_len if use_cascade else 0