[gpt-oss] Enable gpt-oss on ampere (#22714)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
parent: b8a9d0e429
commit: 007dd90859
@@ -25,5 +25,6 @@ class DummyPlatform(Platform):
         compilation_config.custom_ops = ["all"]
 
     def get_attn_backend_cls(self, backend_name, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1, use_mla):
-        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
+                             kv_cache_dtype, block_size, use_v1, use_mla,
+                             has_sink):
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
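The test plugin has to track the interface change: get_attn_backend_cls now receives a has_sink argument, so every platform implementation, in-tree or out-of-tree, must accept it. A minimal sketch, assuming a hypothetical external plugin (MyPlatform and the returned path are illustrative names, not part of vLLM):

class MyPlatform:  # would subclass vllm.platforms.Platform in a real plugin
    def get_attn_backend_cls(self, backend_name, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink):
        # Plugins that do not care about attention sinks can accept and
        # ignore the flag; keeping the old signature would raise a TypeError
        # once the selector starts passing the extra positional argument.
        return "my_plugin.attention.MyAttentionBackend"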
@@ -138,6 +138,7 @@ class Attention(nn.Module):
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
+        self.has_sink = extra_impl_args.get("sinks") is not None
 
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
@@ -165,7 +166,8 @@ class Attention(nn.Module):
                 kv_cache_dtype,
                 block_size,
                 is_attention_free,
-                use_mla=use_mla)
+                use_mla=use_mla,
+                has_sink=self.has_sink)
         else:
             self.attn_backend = attn_backend
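The two layer hunks above detect sinks and forward the result into backend selection. A minimal sketch of just the detection step, assuming extra_impl_args is the layer's keyword-argument dict as in the diff (detect_sink itself is illustrative):

def detect_sink(extra_impl_args: dict) -> bool:
    # A layer uses attention sinks only if the model passed a "sinks"
    # tensor through its extra implementation arguments; its mere presence
    # is what forces a sink-aware attention backend.
    return extra_impl_args.get("sinks") is not None

# gpt-oss-style layers pass sink logits, plain layers do not.
assert detect_sink({"sinks": object()}) is True
assert detect_sink({}) is False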
@@ -144,6 +144,7 @@ def get_attn_backend(
     block_size: int,
     is_attention_free: bool = False,
     use_mla: bool = False,
+    has_sink: bool = False,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@@ -158,6 +159,7 @@ def get_attn_backend(
         is_attention_free=is_attention_free,
         use_v1=envs.VLLM_USE_V1,
         use_mla=use_mla,
+        has_sink=has_sink,
     )
@@ -170,6 +172,7 @@ def _cached_get_attn_backend(
     is_attention_free: bool,
     use_v1: bool = False,
     use_mla: bool = False,
+    has_sink: bool = False,
 ) -> type[AttentionBackend]:
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
@@ -201,7 +204,7 @@ def _cached_get_attn_backend(
     # get device-specific attn_backend
     attention_cls = current_platform.get_attn_backend_cls(
         selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
-        use_mla)
+        use_mla, has_sink)
     if not attention_cls:
         raise ValueError(
             f"Invalid attention backend for {current_platform.device_name}")
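Because the selector result is memoized (the diff's own comment mentions an @lru_cache decorator), adding has_sink to _cached_get_attn_backend also makes it part of the cache key, so sink and non-sink layers in one process can resolve to different backends. A toy illustration, not vLLM code (pick_backend is a stand-in name):

from functools import lru_cache

@lru_cache
def pick_backend(use_mla: bool = False, has_sink: bool = False) -> str:
    # Stand-in for _cached_get_attn_backend: each distinct argument tuple
    # gets its own cache entry, so the two calls below do not collide.
    return "TRITON_ATTN_VLLM_V1" if has_sink else "FLASH_ATTN_V1"

assert pick_backend(has_sink=True) == "TRITON_ATTN_VLLM_V1"
assert pick_backend(has_sink=False) == "FLASH_ATTN_V1"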
@@ -42,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 90
+        return 80
 
     @classmethod
     def get_name(cls) -> QuantizationMethods:
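get_min_capability encodes compute capability as major * 10 + minor, so dropping the floor from 90 (SM 9.0, Hopper) to 80 (SM 8.0, Ampere) is what lets MXFP4-quantized gpt-oss checkpoints load on A100-class GPUs. A hedged sketch of the comparison (mxfp4_supported is illustrative, not the library's actual check):

def mxfp4_supported(device_capability: int, min_capability: int = 80) -> bool:
    # device_capability is major * 10 + minor, e.g. 80 for an A100 (SM 8.0).
    return device_capability >= min_capability

assert mxfp4_supported(80)        # Ampere now passes the gate
assert mxfp4_supported(90)        # Hopper still passes
assert not mxfp4_supported(75)    # Turing (SM 7.5) remains unsupported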
@@ -91,8 +91,8 @@ class CpuPlatform(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool) -> str:
         if selected_backend and selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:
@@ -222,8 +222,8 @@ class CudaPlatformBase(Platform):
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1,
-                             use_mla) -> str:
+                             kv_cache_dtype, block_size, use_v1, use_mla,
+                             has_sink) -> str:
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
@@ -321,6 +321,9 @@ class CudaPlatformBase(Platform):
 
             # FlashAttention is the default for SM 8.0+ GPUs
             if cls.has_device_capability(80):
+                if has_sink:
+                    logger.info_once("Using Triton backend on V1 engine.")
+                    return TRITON_ATTN_VLLM_V1
                 if is_default_backend_supported := is_attn_backend_supported(
                         FLASH_ATTN_V1, head_size, dtype,
                         allow_import_error=False):
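On an SM 8.0+ GPU the hunk above makes sinks the first thing checked: sink layers short-circuit to the Triton V1 backend before the FlashAttention default is even probed. A simplified sketch of that ordering (the function and the fallback name are illustrative, not the real CudaPlatformBase logic):

def cuda_backend_for(has_sink: bool, flash_attn_supported: bool) -> str:
    if has_sink:
        # Mirrors the new branch: sink layers always route to Triton.
        return "TRITON_ATTN_VLLM_V1"
    if flash_attn_supported:
        return "FLASH_ATTN_V1"
    return "FALLBACK_BACKEND"  # hypothetical name for this sketch

assert cuda_backend_for(True, True) == "TRITON_ATTN_VLLM_V1"
assert cuda_backend_for(False, True) == "FLASH_ATTN_V1"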
@@ -196,8 +196,8 @@ class Platform:
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool) -> str:
         """Get the attention backend class of a device."""
         return ""
@@ -188,8 +188,8 @@ class RocmPlatform(Platform):
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1,
-                             use_mla) -> str:
+                             kv_cache_dtype, block_size, use_v1, use_mla,
+                             has_sink) -> str:
         if use_mla:
             from vllm.attention.backends.rocm_aiter_mla import (
                 is_aiter_mla_enabled)
@@ -46,8 +46,8 @@ class TpuPlatform(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink) -> str:
         if (selected_backend != _Backend.PALLAS
                 and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
@@ -35,8 +35,8 @@ class XPUPlatform(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool) -> str:
         if selected_backend is not None and selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
         use_v1 = envs.VLLM_USE_V1