[gpt-oss] Enable gpt-oss on ampere (#22714)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Yongye Zhu 2025-08-12 06:21:44 -04:00 committed by GitHub
parent b8a9d0e429
commit 007dd90859
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 26 additions and 17 deletions

View File

@ -25,5 +25,6 @@ class DummyPlatform(Platform):
compilation_config.custom_ops = ["all"]
def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501

View File

@ -138,6 +138,7 @@ class Attention(nn.Module):
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
@ -165,7 +166,8 @@ class Attention(nn.Module):
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla)
use_mla=use_mla,
has_sink=self.has_sink)
else:
self.attn_backend = attn_backend

View File

@ -144,6 +144,7 @@ def get_attn_backend(
block_size: int,
is_attention_free: bool = False,
use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
@ -158,6 +159,7 @@ def get_attn_backend(
is_attention_free=is_attention_free,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
has_sink=has_sink,
)
@ -170,6 +172,7 @@ def _cached_get_attn_backend(
is_attention_free: bool,
use_v1: bool = False,
use_mla: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
@ -201,7 +204,7 @@ def _cached_get_attn_backend(
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
use_mla)
use_mla, has_sink)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")

View File

@ -42,7 +42,7 @@ class Mxfp4Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 90
return 80
@classmethod
def get_name(cls) -> QuantizationMethods:

View File

@ -91,8 +91,8 @@ class CpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str:
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:

View File

@ -222,8 +222,8 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
@ -321,6 +321,9 @@ class CudaPlatformBase(Platform):
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if has_sink:
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
if is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype,
allow_import_error=False):

View File

@ -196,8 +196,8 @@ class Platform:
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str:
"""Get the attention backend class of a device."""
return ""

View File

@ -188,8 +188,8 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
if use_mla:
from vllm.attention.backends.rocm_aiter_mla import (
is_aiter_mla_enabled)

View File

@ -46,8 +46,8 @@ class TpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
block_size: int, use_v1: bool, use_mla: bool,
has_sink) -> str:
if (selected_backend != _Backend.PALLAS
and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend)

View File

@ -35,8 +35,8 @@ class XPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str:
if selected_backend is not None and selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
use_v1 = envs.VLLM_USE_V1