mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 17:47:08 +08:00
[attention][DCP] use AttentionImpl.need_to_return_lse_for_decode (#24372)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
4172235ab7
commit
558f0907dc
@ -257,6 +257,32 @@ class AttentionLayer(Protocol):
|
|||||||
|
|
||||||
class AttentionImpl(ABC, Generic[T]):
|
class AttentionImpl(ABC, Generic[T]):
|
||||||
|
|
||||||
|
# Whether the attention impl can return the softmax lse for decode.
|
||||||
|
# Some features like decode context parallelism require the softmax lse.
|
||||||
|
can_return_lse_for_decode: bool = False
|
||||||
|
|
||||||
|
# some attention backends might not always want to return lse
|
||||||
|
# even if they can return lse (for efficiency reasons)
|
||||||
|
need_to_return_lse_for_decode: bool = False
|
||||||
|
|
||||||
|
dcp_world_size: int
|
||||||
|
dcp_rank: int
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# use __new__ so that all subclasses will call this
|
||||||
|
self = super().__new__(cls)
|
||||||
|
try:
|
||||||
|
from vllm.distributed.parallel_state import get_dcp_group
|
||||||
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
|
self.dcp_rank = get_dcp_group().rank_in_group
|
||||||
|
except AssertionError:
|
||||||
|
# DCP might not be initialized in testing
|
||||||
|
self.dcp_world_size = 1
|
||||||
|
self.dcp_rank = 0
|
||||||
|
self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
|
||||||
|
and self.can_return_lse_for_decode
|
||||||
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
|
|
||||||
# recorect dcp attn_out with lse.
|
# recorect dcp attn_out with lse.
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
assert lse is not None, (
|
|
||||||
"For a mla backend want to enable"
|
|
||||||
"DCP, it is mandatory that the corresponding decode attn"
|
|
||||||
"kernel return the softmax lse.")
|
|
||||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
||||||
|
|
||||||
# v_up projection
|
# v_up projection
|
||||||
|
|||||||
@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
|
|
||||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||||
|
|
||||||
|
can_return_lse_for_decode: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
|
|||||||
@ -56,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
|||||||
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
||||||
get_dtype_size, is_pin_memory_available, round_up,
|
get_dtype_size, is_pin_memory_available, round_up,
|
||||||
supports_dynamo)
|
supports_dynamo)
|
||||||
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
|
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||||
create_fast_prefill_custom_backend,
|
create_fast_prefill_custom_backend,
|
||||||
@ -3405,10 +3404,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
copy_kv_blocks)
|
copy_kv_blocks)
|
||||||
|
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
assert self.attn_groups[0][0].backend is FlashMLABackend, (
|
layer_names = self.attn_groups[0][0].layer_names
|
||||||
"DCP only support flashmla now."
|
layers = get_layers_from_vllm_config(self.vllm_config,
|
||||||
"For a mla backend want to enable DCP, it is mandatory that the"
|
AttentionLayerBase,
|
||||||
"corresponding decode attn kernel return the softmax lse.")
|
layer_names)
|
||||||
|
for layer in layers.values():
|
||||||
|
assert layer.impl.need_to_return_lse_for_decode, (
|
||||||
|
"DCP requires attention impls to return"
|
||||||
|
" the softmax lse for decode, but the impl "
|
||||||
|
f"{layer.impl.__class__.__name__} "
|
||||||
|
"does not return the softmax lse for decode.")
|
||||||
|
|
||||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user