From 558f0907dc67c804c5821c92f4d64ed43d10489b Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 7 Sep 2025 09:18:59 +0800
Subject: [PATCH] [attention][DCP] use AttentionImpl.need_to_return_lse_for_decode (#24372)

Signed-off-by: youkaichao
---
 vllm/attention/backends/abstract.py         | 26 ++++++++++++++++++++++
 vllm/v1/attention/backends/mla/common.py    |  4 ----
 vllm/v1/attention/backends/mla/flashmla.py  |  2 ++
 vllm/v1/worker/gpu_model_runner.py          | 15 ++++++++-----
 4 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 0b9c625533cb7..0217bff6adafa 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -257,6 +257,32 @@ class AttentionLayer(Protocol):
 
 
 class AttentionImpl(ABC, Generic[T]):
+    # Whether the attention impl can return the softmax lse for decode.
+    # Some features like decode context parallelism require the softmax lse.
+    can_return_lse_for_decode: bool = False
+
+    # some attention backends might not always want to return lse
+    # even if they can return lse (for efficiency reasons)
+    need_to_return_lse_for_decode: bool = False
+
+    dcp_world_size: int
+    dcp_rank: int
+
+    def __new__(cls, *args, **kwargs):
+        # use __new__ so that all subclasses will call this
+        self = super().__new__(cls)
+        try:
+            from vllm.distributed.parallel_state import get_dcp_group
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+        self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
+            and self.can_return_lse_for_decode
+        return self
+
     @abstractmethod
     def __init__(
         self,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 090ebf93840d8..ec1216a16bc46 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 
         # recorect dcp attn_out with lse.
         if self.dcp_world_size > 1:
-            assert lse is not None, (
-                "For a mla backend want to enable"
-                "DCP, it is mandatory that the corresponding decode attn"
-                "kernel return the softmax lse.")
             attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
 
         # v_up projection
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 11c91b8a0650e..1824bbadb6a1a 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
 
 
 class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
+    can_return_lse_for_decode: bool = True
+
     def __init__(
         self,
         num_heads: int,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 563872f8d68f3..0224f3944f005 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -56,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
-from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
@@ -3405,10 +3404,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 copy_kv_blocks)
 
         if self.dcp_world_size > 1:
-            assert self.attn_groups[0][0].backend is FlashMLABackend, (
-                "DCP only support flashmla now."
-                "For a mla backend want to enable DCP, it is mandatory that the"
-                "corresponding decode attn kernel return the softmax lse.")
+            layer_names = self.attn_groups[0][0].layer_names
+            layers = get_layers_from_vllm_config(self.vllm_config,
+                                                 AttentionLayerBase,
+                                                 layer_names)
+            for layer in layers.values():
+                assert layer.impl.need_to_return_lse_for_decode, (
+                    "DCP requires attention impls to return"
+                    " the softmax lse for decode, but the impl "
+                    f"{layer.impl.__class__.__name__} "
+                    "does not return the softmax lse for decode.")
 
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
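
Note: the following is a minimal, self-contained sketch of the opt-in pattern this patch introduces, for illustration only. The helper _get_dcp_world_size_and_rank and the toy classes FlashMLALikeImpl and NoLseImpl are assumptions made for the sketch; in vLLM the values come from vllm.distributed.parallel_state.get_dcp_group() and the real opt-in is FlashMLAImpl.can_return_lse_for_decode.

    from abc import ABC


    def _get_dcp_world_size_and_rank() -> tuple[int, int]:
        # Stand-in for vllm.distributed.parallel_state.get_dcp_group();
        # pretend decode context parallelism (DCP) is running with two ranks.
        return 2, 0


    class AttentionImpl(ABC):
        # Backends whose decode kernel can compute the softmax lse opt in here.
        can_return_lse_for_decode: bool = False
        # Derived in __new__: only return the lse when DCP actually needs it.
        need_to_return_lse_for_decode: bool = False

        dcp_world_size: int
        dcp_rank: int

        def __new__(cls, *args, **kwargs):
            # __new__ runs for every subclass, even if the subclass __init__
            # never calls super().__init__(), so the DCP fields are always set.
            self = super().__new__(cls)
            try:
                self.dcp_world_size, self.dcp_rank = _get_dcp_world_size_and_rank()
            except AssertionError:
                # DCP might not be initialized (e.g. in unit tests).
                self.dcp_world_size = 1
                self.dcp_rank = 0
            self.need_to_return_lse_for_decode = (
                self.dcp_world_size > 1 and self.can_return_lse_for_decode)
            return self


    class FlashMLALikeImpl(AttentionImpl):
        # Opt-in, mirroring FlashMLAImpl: the decode kernel can return the lse.
        can_return_lse_for_decode: bool = True

        def __init__(self, num_heads: int):
            self.num_heads = num_heads


    class NoLseImpl(AttentionImpl):
        # Keeps the default: cannot return the lse for decode.
        def __init__(self, num_heads: int):
            self.num_heads = num_heads


    if __name__ == "__main__":
        for impl in (FlashMLALikeImpl(num_heads=8), NoLseImpl(num_heads=8)):
            print(type(impl).__name__, impl.need_to_return_lse_for_decode)
        # Prints:
        #   FlashMLALikeImpl True   (DCP world size > 1 and the backend opted in)
        #   NoLseImpl False         (the runner-side DCP check would reject this)

Deriving the flag in __new__ rather than __init__ means every backend picks up dcp_world_size, dcp_rank, and need_to_return_lse_for_decode even if its __init__ never calls super().__init__(), which is what lets gpu_model_runner.py assert uniformly on layer.impl.need_to_return_lse_for_decode instead of hard-coding a check for FlashMLABackend.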