mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 05:58:42 +08:00
[CI Failure] Fix backend selection for encoder-only models (#28534)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
parent
a7791eac9d
commit
07a606aa7e
@ -142,6 +142,17 @@ class AttentionBackend(ABC):
|
||||
def is_sparse(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""Check if backend supports a given attention type.
|
||||
|
||||
By default, only supports decoder attention.
|
||||
Backends should override this to support other attention types.
|
||||
"""
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
return attn_type == AttentionType.DECODER
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
|
||||
return True
|
||||
@ -171,6 +182,7 @@ class AttentionBackend(ABC):
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
attn_type: str,
|
||||
) -> list[str]:
|
||||
invalid_reasons = []
|
||||
if not cls.supports_head_size(head_size):
|
||||
@ -195,6 +207,8 @@ class AttentionBackend(ABC):
|
||||
invalid_reasons.append("non-sparse not supported")
|
||||
if not cls.supports_compute_capability(device_capability):
|
||||
invalid_reasons.append("compute capability not supported")
|
||||
if not cls.supports_attn_type(attn_type):
|
||||
invalid_reasons.append(f"attention type {attn_type} not supported")
|
||||
combination_reason = cls.supports_combination(
|
||||
head_size,
|
||||
dtype,
|
||||
|
||||
@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
block_size,
|
||||
use_mla=False,
|
||||
has_sink=self.has_sink,
|
||||
attn_type=attn_type,
|
||||
)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
|
||||
block_size = 16
|
||||
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size, dtype, kv_cache_dtype, block_size
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
attn_type=AttentionType.ENCODER_ONLY,
|
||||
)
|
||||
|
||||
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
|
||||
|
||||
@ -76,6 +76,7 @@ def get_attn_backend(
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
attn_type: str | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
|
||||
@ -94,6 +95,7 @@ def get_attn_backend(
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
use_sparse=use_sparse,
|
||||
attn_type=attn_type,
|
||||
)
|
||||
|
||||
|
||||
@ -106,6 +108,7 @@ def _cached_get_attn_backend(
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
attn_type: str | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
@ -159,6 +162,7 @@ def _cached_get_attn_backend(
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
attn_type,
|
||||
)
|
||||
else:
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
@ -170,6 +174,7 @@ def _cached_get_attn_backend(
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
attn_type,
|
||||
)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
|
||||
@ -134,6 +134,7 @@ class CpuPlatform(Platform):
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
attn_type,
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", list[str]],
|
||||
@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
attn_type,
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons_i = ["ImportError"]
|
||||
@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
if attn_type is None:
|
||||
attn_type = AttentionType.DECODER
|
||||
|
||||
device_capability = cls.get_device_capability()
|
||||
assert device_capability is not None
|
||||
|
||||
@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
attn_type,
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons = ["ImportError"]
|
||||
@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
attn_type,
|
||||
)
|
||||
reasons_str = (
|
||||
"{"
|
||||
|
||||
@ -222,6 +222,7 @@ class Platform:
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
"""Get the attention backend class of a device."""
|
||||
return ""
|
||||
|
||||
@ -216,6 +216,7 @@ class RocmPlatform(Platform):
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
@ -61,6 +61,7 @@ class TpuPlatform(Platform):
|
||||
use_mla: bool,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
@ -51,6 +51,7 @@ class XPUPlatform(Platform):
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
|
||||
@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
|
||||
def get_name() -> str:
|
||||
return "CPU_ATTN"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""CPU attention supports decoder and encoder-only attention."""
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
|
||||
return CPUAttentionBackendImpl
|
||||
|
||||
@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""FlashAttention supports all attention types."""
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttentionImpl"]:
|
||||
return FlashAttentionImpl
|
||||
|
||||
@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
|
||||
def get_name() -> str:
|
||||
return "FLEX_ATTENTION"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""FlexAttention supports both decoder and encoder-only attention."""
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlexAttentionImpl"]:
|
||||
return FlexAttentionImpl
|
||||
|
||||
@ -40,14 +40,14 @@ logger = init_logger(__name__)
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user