[CI Failure] Fix backend selection for encoder-only models (#28534)

Signed-off-by: Huamin Li <3ericli@gmail.com>
Huamin Li 2025-11-13 07:11:27 -08:00 committed by GitHub
parent a7791eac9d
commit 07a606aa7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 75 additions and 6 deletions

View File

@ -142,6 +142,17 @@ class AttentionBackend(ABC):
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention import AttentionType
return attn_type == AttentionType.DECODER
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True
@ -171,6 +182,7 @@ class AttentionBackend(ABC):
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
@ -195,6 +207,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,

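A minimal sketch of how a backend opts into the new hook (the subclass name is illustrative, and the vllm.attention.backends.abstract import path is assumed since this diff view omits file names): override the classmethod, and the validation method patched above stops reporting "attention type ... not supported" for the accepted types.

# Illustrative sketch only; "MyEncoderCapableBackend" is hypothetical and the
# abstract-base import path is assumed rather than shown in this diff.
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend


class MyEncoderCapableBackend(AttentionBackend):
    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Accept decoder and encoder-only attention; any other type is
        # rejected and surfaces as "attention type ... not supported".
        return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)


# The base-class default still accepts only decoder attention.
assert AttentionBackend.supports_attn_type(AttentionType.DECODER)
assert not AttentionBackend.supports_attn_type(AttentionType.ENCODER_ONLY)
assert MyEncoderCapableBackend.supports_attn_type(AttentionType.ENCODER_ONLY)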
View File

@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size,
use_mla=False,
has_sink=self.has_sink,
attn_type=attn_type,
)
else:
self.attn_backend = attn_backend

View File

@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

View File

@ -76,6 +76,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
@ -94,6 +95,7 @@ def get_attn_backend(
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
attn_type=attn_type,
)
@ -106,6 +108,7 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
@ -159,6 +162,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
@ -170,6 +174,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
attn_type,
)
if not attention_cls:
raise ValueError(

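For callers, the selection entry point mirrors the EncoderOnlyAttention change above. A usage sketch with placeholder values, assuming the usual vllm.attention.selector location for get_attn_backend:

# Placeholder head size, dtype, and block size; the selector import path is
# assumed, as file names are not shown in this view.
import torch

from vllm.attention import AttentionType
from vllm.attention.selector import get_attn_backend

backend_cls = get_attn_backend(
    head_size=64,
    dtype=torch.float16,
    kv_cache_dtype="auto",
    block_size=16,
    attn_type=AttentionType.ENCODER_ONLY,  # new: filters out decoder-only backends
)
print(backend_cls.get_name())

Passing attn_type=None (or omitting it) keeps the previous behavior, since platforms default it to AttentionType.DECODER.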
View File

@ -134,6 +134,7 @@ class CpuPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum

View File

@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons_i = ["ImportError"]
@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability()
assert device_capability is not None
@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons = ["ImportError"]
@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
reasons_str = (
"{"

View File

@ -222,6 +222,7 @@ class Platform:
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""
return ""

View File

@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum

View File

@ -61,6 +61,7 @@ class TpuPlatform(Platform):
use_mla: bool,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum

View File

@ -51,6 +51,7 @@ class XPUPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout

View File

@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
def get_name() -> str:
return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
)
@staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl

View File

@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLASH_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl

View File

@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLEX_ATTENTION"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl

View File

@ -40,14 +40,14 @@ logger = init_logger(__name__)
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
- **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values,
the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy.
"""