[CI Failure] Fix backend selection for encoder-only models (#28534)

Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
Huamin Li 2025-11-13 07:11:27 -08:00 committed by GitHub
parent a7791eac9d
commit 07a606aa7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 75 additions and 6 deletions

View File

@ -142,6 +142,17 @@ class AttentionBackend(ABC):
def is_sparse(cls) -> bool: def is_sparse(cls) -> bool:
return False return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention import AttentionType
return attn_type == AttentionType.DECODER
@classmethod @classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True return True
@ -171,6 +182,7 @@ class AttentionBackend(ABC):
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
device_capability: "DeviceCapability", device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]: ) -> list[str]:
invalid_reasons = [] invalid_reasons = []
if not cls.supports_head_size(head_size): if not cls.supports_head_size(head_size):
@ -195,6 +207,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("non-sparse not supported") invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability): if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported") invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination( combination_reason = cls.supports_combination(
head_size, head_size,
dtype, dtype,

View File

@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size, block_size,
use_mla=False, use_mla=False,
has_sink=self.has_sink, has_sink=self.has_sink,
attn_type=attn_type,
) )
else: else:
self.attn_backend = attn_backend self.attn_backend = attn_backend

View File

@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
block_size = 16 block_size = 16
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
) )
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend) attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

View File

@ -76,6 +76,7 @@ def get_attn_backend(
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
@ -94,6 +95,7 @@ def get_attn_backend(
use_mla=use_mla, use_mla=use_mla,
has_sink=has_sink, has_sink=has_sink,
use_sparse=use_sparse, use_sparse=use_sparse,
attn_type=attn_type,
) )
@ -106,6 +108,7 @@ def _cached_get_attn_backend(
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
# Check whether a particular choice of backend was # Check whether a particular choice of backend was
# previously forced. # previously forced.
@ -159,6 +162,7 @@ def _cached_get_attn_backend(
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type,
) )
else: else:
attention_cls = current_platform.get_attn_backend_cls( attention_cls = current_platform.get_attn_backend_cls(
@ -170,6 +174,7 @@ def _cached_get_attn_backend(
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type,
) )
if not attention_cls: if not attention_cls:
raise ValueError( raise ValueError(

View File

@ -134,6 +134,7 @@ class CpuPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum

View File

@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]], dict["AttentionBackendEnum", list[str]],
@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons_i = ["ImportError"] invalid_reasons_i = ["ImportError"]
@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability() device_capability = cls.get_device_capability()
assert device_capability is not None assert device_capability is not None
@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons = ["ImportError"] invalid_reasons = ["ImportError"]
@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
reasons_str = ( reasons_str = (
"{" "{"

View File

@ -222,6 +222,7 @@ class Platform:
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
return "" return ""

View File

@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum

View File

@ -61,6 +61,7 @@ class TpuPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum

View File

@ -51,6 +51,7 @@ class XPUPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.attention.backends.utils import set_kv_cache_layout

View File

@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "CPU_ATTN" return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
)
@staticmethod @staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]: def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl return CPUAttentionBackendImpl

View File

@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod @staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]: def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl return FlashAttentionImpl

View File

@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLEX_ATTENTION" return "FLEX_ATTENTION"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod @staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]: def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl return FlexAttentionImpl