[attention][DCP] use AttentionImpl.need_to_return_lse_for_decode (#24372)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-09-07 09:18:59 +08:00 committed by GitHub
parent 4172235ab7
commit 558f0907dc
4 changed files with 38 additions and 9 deletions


@@ -257,6 +257,32 @@ class AttentionLayer(Protocol):

class AttentionImpl(ABC, Generic[T]):
    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # Some attention backends might not always want to return the lse,
    # even if they can (for efficiency reasons).
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # Use __new__ so that all subclasses will call this.
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
            and self.can_return_lse_for_decode
        return self

    @abstractmethod
    def __init__(
        self,

@@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):

        # Correct the DCP attn_out with the lse.
        if self.dcp_world_size > 1:
            assert lse is not None, (
                "For an MLA backend to enable DCP, it is mandatory that "
                "the corresponding decode attn kernel returns the "
                "softmax lse.")
            attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())

        # v_up projection
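Why the lse is mandatory here: under DCP each rank attends to only a slice of the KV cache, so its decode output is a partial softmax-weighted sum, and the partial results can only be merged correctly if every rank also reports its log-sum-exp. The sketch below shows the underlying single-device math only; it is not the cp_lse_ag_out_rs implementation, which, as its name suggests, also all-gathers the lses and reduce-scatters the outputs across the DCP group.

import torch

def merge_partial_attn(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    """Merge per-rank partial attention outputs using their lses.

    outs: [world_size, num_heads, head_dim], each slice computed over a
          disjoint chunk of the KV cache.
    lses: [world_size, num_heads], log-sum-exp of the attention logits for
          the same chunks.
    """
    lse_total = torch.logsumexp(lses, dim=0)             # [num_heads]
    weights = torch.exp(lses - lse_total).unsqueeze(-1)  # [world, heads, 1]
    return (weights * outs).sum(dim=0)                   # [num_heads, head_dim]

# Sanity check: splitting the KV cache in two and merging with lses matches
# attending over the full cache at once.
torch.manual_seed(0)
q = torch.randn(4, 64)                                   # [num_heads, head_dim]
k, v = torch.randn(32, 64), torch.randn(32, 64)

def attn(q, k, v):
    logits = q @ k.T / 64 ** 0.5
    return torch.softmax(logits, dim=-1) @ v, torch.logsumexp(logits, dim=-1)

full_out, _ = attn(q, k, v)
o1, l1 = attn(q, k[:16], v[:16])
o2, l2 = attn(q, k[16:], v[16:])
merged = merge_partial_attn(torch.stack([o1, o2]), torch.stack([l1, l2]))
assert torch.allclose(merged, full_out, atol=1e-5)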


@@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
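A backend that declares can_return_lse_for_decode = True is still expected to honor the derived need_to_return_lse_for_decode flag in its decode path, so the lse is only materialized when DCP actually needs it. The snippet below is a hypothetical illustration of that contract; toy_decode_kernel and ToyDecodeImpl are made-up placeholders, not FlashMLA's actual decode code.

from typing import Optional
import torch

def toy_decode_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      return_lse: bool):
    # Stand-in for a real decode kernel: returns the lse only when asked,
    # mirroring the "can return but need not return" distinction above.
    logits = q @ k.T / q.shape[-1] ** 0.5
    out = torch.softmax(logits, dim=-1) @ v
    lse = torch.logsumexp(logits, dim=-1) if return_lse else None
    return out, lse

class ToyDecodeImpl:
    # Declared capability, as FlashMLAImpl does in the hunk above.
    can_return_lse_for_decode: bool = True
    # Derived from the DCP world size (see the AttentionImpl hunk above).
    need_to_return_lse_for_decode: bool = False

    def forward_decode(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Only pay for the lse when DCP actually needs it.
        return toy_decode_kernel(
            q, k, v, return_lse=self.need_to_return_lse_for_decode)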


@@ -56,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
    create_fast_prefill_custom_backend,
@@ -3405,10 +3404,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            copy_kv_blocks)

        if self.dcp_world_size > 1:
            assert self.attn_groups[0][0].backend is FlashMLABackend, (
                "DCP only supports FlashMLA for now. For an MLA backend to "
                "enable DCP, it is mandatory that the corresponding decode "
                "attn kernel returns the softmax lse.")
            layer_names = self.attn_groups[0][0].layer_names
            layers = get_layers_from_vllm_config(self.vllm_config,
                                                 AttentionLayerBase,
                                                 layer_names)
            for layer in layers.values():
                assert layer.impl.need_to_return_lse_for_decode, (
                    "DCP requires attention impls to return the softmax lse "
                    "for decode, but the impl "
                    f"{layer.impl.__class__.__name__} "
                    "does not return the softmax lse for decode.")

    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
        """