mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 22:17:54 +08:00
[CI Failure] Fix backend selection for encoder-only models (#28534)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
parent
a7791eac9d
commit
07a606aa7e
@ -142,6 +142,17 @@ class AttentionBackend(ABC):
|
|||||||
def is_sparse(cls) -> bool:
|
def is_sparse(cls) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether this backend can serve the requested attention type.

    The base implementation accepts decoder attention only; a backend
    that handles other attention types should override this method.

    Args:
        attn_type: The attention type being requested.

    Returns:
        True only when ``attn_type`` equals ``AttentionType.DECODER``.
    """
    # Imported at call time, inside the method, exactly as the original does.
    from vllm.attention import AttentionType

    decoder_type = AttentionType.DECODER
    return attn_type == decoder_type
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
|
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
|
||||||
return True
|
return True
|
||||||
@ -171,6 +182,7 @@ class AttentionBackend(ABC):
|
|||||||
has_sink: bool,
|
has_sink: bool,
|
||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
device_capability: "DeviceCapability",
|
device_capability: "DeviceCapability",
|
||||||
|
attn_type: str,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
invalid_reasons = []
|
invalid_reasons = []
|
||||||
if not cls.supports_head_size(head_size):
|
if not cls.supports_head_size(head_size):
|
||||||
@ -195,6 +207,8 @@ class AttentionBackend(ABC):
|
|||||||
invalid_reasons.append("non-sparse not supported")
|
invalid_reasons.append("non-sparse not supported")
|
||||||
if not cls.supports_compute_capability(device_capability):
|
if not cls.supports_compute_capability(device_capability):
|
||||||
invalid_reasons.append("compute capability not supported")
|
invalid_reasons.append("compute capability not supported")
|
||||||
|
if not cls.supports_attn_type(attn_type):
|
||||||
|
invalid_reasons.append(f"attention type {attn_type} not supported")
|
||||||
combination_reason = cls.supports_combination(
|
combination_reason = cls.supports_combination(
|
||||||
head_size,
|
head_size,
|
||||||
dtype,
|
dtype,
|
||||||
|
|||||||
@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
block_size,
|
block_size,
|
||||||
use_mla=False,
|
use_mla=False,
|
||||||
has_sink=self.has_sink,
|
has_sink=self.has_sink,
|
||||||
|
attn_type=attn_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.attn_backend = attn_backend
|
self.attn_backend = attn_backend
|
||||||
|
|||||||
@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
|
|||||||
block_size = 16
|
block_size = 16
|
||||||
|
|
||||||
underlying_attn_backend = get_attn_backend(
|
underlying_attn_backend = get_attn_backend(
|
||||||
head_size, dtype, kv_cache_dtype, block_size
|
head_size,
|
||||||
|
dtype,
|
||||||
|
kv_cache_dtype,
|
||||||
|
block_size,
|
||||||
|
attn_type=AttentionType.ENCODER_ONLY,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
|
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
|
||||||
|
|||||||
@ -76,6 +76,7 @@ def get_attn_backend(
|
|||||||
use_mla: bool = False,
|
use_mla: bool = False,
|
||||||
has_sink: bool = False,
|
has_sink: bool = False,
|
||||||
use_sparse: bool = False,
|
use_sparse: bool = False,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
"""Selects which attention backend to use and lazily imports it."""
|
"""Selects which attention backend to use and lazily imports it."""
|
||||||
|
|
||||||
@ -94,6 +95,7 @@ def get_attn_backend(
|
|||||||
use_mla=use_mla,
|
use_mla=use_mla,
|
||||||
has_sink=has_sink,
|
has_sink=has_sink,
|
||||||
use_sparse=use_sparse,
|
use_sparse=use_sparse,
|
||||||
|
attn_type=attn_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -106,6 +108,7 @@ def _cached_get_attn_backend(
|
|||||||
use_mla: bool = False,
|
use_mla: bool = False,
|
||||||
has_sink: bool = False,
|
has_sink: bool = False,
|
||||||
use_sparse: bool = False,
|
use_sparse: bool = False,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
# Check whether a particular choice of backend was
|
# Check whether a particular choice of backend was
|
||||||
# previously forced.
|
# previously forced.
|
||||||
@ -159,6 +162,7 @@ def _cached_get_attn_backend(
|
|||||||
use_mla,
|
use_mla,
|
||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
|
attn_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
attention_cls = current_platform.get_attn_backend_cls(
|
attention_cls = current_platform.get_attn_backend_cls(
|
||||||
@ -170,6 +174,7 @@ def _cached_get_attn_backend(
|
|||||||
use_mla,
|
use_mla,
|
||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
|
attn_type,
|
||||||
)
|
)
|
||||||
if not attention_cls:
|
if not attention_cls:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@ -134,6 +134,7 @@ class CpuPlatform(Platform):
|
|||||||
use_mla: bool,
|
use_mla: bool,
|
||||||
has_sink: bool,
|
has_sink: bool,
|
||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|
||||||
|
|||||||
@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
|
|||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
device_capability,
|
device_capability,
|
||||||
|
attn_type,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
list[tuple["AttentionBackendEnum", int]],
|
list[tuple["AttentionBackendEnum", int]],
|
||||||
dict["AttentionBackendEnum", list[str]],
|
dict["AttentionBackendEnum", list[str]],
|
||||||
@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
|
|||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
device_capability,
|
device_capability,
|
||||||
|
attn_type,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
invalid_reasons_i = ["ImportError"]
|
invalid_reasons_i = ["ImportError"]
|
||||||
@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
|
|||||||
use_mla: bool,
|
use_mla: bool,
|
||||||
has_sink: bool,
|
has_sink: bool,
|
||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
from vllm.attention import AttentionType
|
||||||
|
|
||||||
|
if attn_type is None:
|
||||||
|
attn_type = AttentionType.DECODER
|
||||||
|
|
||||||
device_capability = cls.get_device_capability()
|
device_capability = cls.get_device_capability()
|
||||||
assert device_capability is not None
|
assert device_capability is not None
|
||||||
|
|
||||||
@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
|
|||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
device_capability,
|
device_capability,
|
||||||
|
attn_type,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
invalid_reasons = ["ImportError"]
|
invalid_reasons = ["ImportError"]
|
||||||
@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
|
|||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
device_capability,
|
device_capability,
|
||||||
|
attn_type,
|
||||||
)
|
)
|
||||||
reasons_str = (
|
reasons_str = (
|
||||||
"{"
|
"{"
|
||||||
|
|||||||
@ -222,6 +222,7 @@ class Platform:
|
|||||||
use_mla: bool,
|
use_mla: bool,
|
||||||
has_sink: bool,
|
has_sink: bool,
|
||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get the attention backend class of a device."""
|
"""Get the attention backend class of a device."""
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@ -216,6 +216,7 @@ class RocmPlatform(Platform):
|
|||||||
use_mla,
|
use_mla,
|
||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|||||||
@ -61,6 +61,7 @@ class TpuPlatform(Platform):
|
|||||||
use_mla: bool,
|
use_mla: bool,
|
||||||
has_sink,
|
has_sink,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|
||||||
|
|||||||
@ -51,6 +51,7 @@ class XPUPlatform(Platform):
|
|||||||
use_mla: bool,
|
use_mla: bool,
|
||||||
has_sink: bool,
|
has_sink: bool,
|
||||||
use_sparse,
|
use_sparse,
|
||||||
|
attn_type: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||||
|
|
||||||
|
|||||||
@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
|
|||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "CPU_ATTN"
|
return "CPU_ATTN"
|
||||||
|
|
||||||
|
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether the CPU backend can serve the requested attention type.

    Args:
        attn_type: The attention type being requested.

    Returns:
        True when ``attn_type`` is one of ``AttentionType.DECODER``,
        ``AttentionType.ENCODER`` or ``AttentionType.ENCODER_ONLY``;
        False otherwise.
    """
    # Imported at call time, inside the method, exactly as the original does.
    from vllm.attention import AttentionType

    accepted_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
    )
    return attn_type in accepted_types
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
|
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
|
||||||
return CPUAttentionBackendImpl
|
return CPUAttentionBackendImpl
|
||||||
|
|||||||
@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "FLASH_ATTN"
|
return "FLASH_ATTN"
|
||||||
|
|
||||||
|
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether FlashAttention can serve the requested attention type.

    Args:
        attn_type: The attention type being requested.

    Returns:
        True when ``attn_type`` is one of ``AttentionType.DECODER``,
        ``AttentionType.ENCODER``, ``AttentionType.ENCODER_ONLY`` or
        ``AttentionType.ENCODER_DECODER``; False otherwise.
    """
    # Imported at call time, inside the method, exactly as the original does.
    from vllm.attention import AttentionType

    accepted_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type in accepted_types
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_impl_cls() -> type["FlashAttentionImpl"]:
|
def get_impl_cls() -> type["FlashAttentionImpl"]:
|
||||||
return FlashAttentionImpl
|
return FlashAttentionImpl
|
||||||
|
|||||||
@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
|
|||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "FLEX_ATTENTION"
|
return "FLEX_ATTENTION"
|
||||||
|
|
||||||
|
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether FlexAttention can serve the requested attention type.

    Args:
        attn_type: The attention type being requested.

    Returns:
        True when ``attn_type`` is ``AttentionType.DECODER`` or
        ``AttentionType.ENCODER_ONLY``; False otherwise.
    """
    # Imported at call time, inside the method, exactly as the original does.
    from vllm.attention import AttentionType

    accepted_types = (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
    return attn_type in accepted_types
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_impl_cls() -> type["FlexAttentionImpl"]:
|
def get_impl_cls() -> type["FlexAttentionImpl"]:
|
||||||
return FlexAttentionImpl
|
return FlexAttentionImpl
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user