From 007dd90859cc0337510536677418a43d8f66e286 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Tue, 12 Aug 2025 06:21:44 -0400 Subject: [PATCH] [gpt-oss] Enable gpt-oss on ampere (#22714) Signed-off-by: Yongye Zhu --- .../vllm_add_dummy_platform/dummy_platform.py | 5 +++-- vllm/attention/layer.py | 4 +++- vllm/attention/selector.py | 5 ++++- vllm/model_executor/layers/quantization/mxfp4.py | 2 +- vllm/platforms/cpu.py | 4 ++-- vllm/platforms/cuda.py | 7 +++++-- vllm/platforms/interface.py | 4 ++-- vllm/platforms/rocm.py | 4 ++-- vllm/platforms/tpu.py | 4 ++-- vllm/platforms/xpu.py | 4 ++-- 10 files changed, 26 insertions(+), 17 deletions(-) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index e67825f89d815..8d0687b49bb47 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -25,5 +25,6 @@ class DummyPlatform(Platform): compilation_config.custom_ops = ["all"] def get_attn_backend_cls(self, backend_name, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla): - return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 \ No newline at end of file + kv_cache_dtype, block_size, use_v1, use_mla, + has_sink): + return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b4c3cbd7c9d64..1a9c0e26b53ca 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -138,6 +138,7 @@ class Attention(nn.Module): self.head_size = head_size self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window + self.has_sink = extra_impl_args.get("sinks") is not None quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None @@ -165,7 +166,8 @@ class 
Attention(nn.Module): kv_cache_dtype, block_size, is_attention_free, - use_mla=use_mla) + use_mla=use_mla, + has_sink=self.has_sink) else: self.attn_backend = attn_backend diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 508470bb363e1..3a235ba6e0b42 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -144,6 +144,7 @@ def get_attn_backend( block_size: int, is_attention_free: bool = False, use_mla: bool = False, + has_sink: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -158,6 +159,7 @@ def get_attn_backend( is_attention_free=is_attention_free, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, + has_sink=has_sink, ) @@ -170,6 +172,7 @@ def _cached_get_attn_backend( is_attention_free: bool, use_v1: bool = False, use_mla: bool = False, + has_sink: bool = False, ) -> type[AttentionBackend]: # If there are no attention layers (e.g. 
we are running Mamba), # use the placeholder NO_ATTENTION @@ -201,7 +204,7 @@ def _cached_get_attn_backend( # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, - use_mla) + use_mla, has_sink) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 4e59aef480fde..03fbcf158338e 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -42,7 +42,7 @@ class Mxfp4Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 90 + return 80 @classmethod def get_name(cls) -> QuantizationMethods: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 31a67183ff12c..0b16a8e1d1d8b 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -91,8 +91,8 @@ class CpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: + block_size: int, use_v1: bool, use_mla: bool, + has_sink: bool) -> str: if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index dd9356e399c9d..c876c52a2e9c9 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -222,8 +222,8 @@ class CudaPlatformBase(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, - use_mla) -> str: + kv_cache_dtype, block_size, use_v1, use_mla, + has_sink) -> str: if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here @@ 
-321,6 +321,9 @@ class CudaPlatformBase(Platform): # FlashAttention is the default for SM 8.0+ GPUs if cls.has_device_capability(80): + if has_sink: + logger.info_once("Using Triton backend on V1 engine.") + return TRITON_ATTN_VLLM_V1 if is_default_backend_supported := is_attn_backend_supported( FLASH_ATTN_V1, head_size, dtype, allow_import_error=False): diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index a85b583abc2ce..91d5314900c87 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -196,8 +196,8 @@ class Platform: @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: + block_size: int, use_v1: bool, use_mla: bool, + has_sink: bool) -> str: """Get the attention backend class of a device.""" return "" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d26e4b3350381..8005830f55cef 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -188,8 +188,8 @@ class RocmPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, - use_mla) -> str: + kv_cache_dtype, block_size, use_v1, use_mla, + has_sink) -> str: if use_mla: from vllm.attention.backends.rocm_aiter_mla import ( is_aiter_mla_enabled) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 146801c9d7739..c56096d93612d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -46,8 +46,8 @@ class TpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: + block_size: int, use_v1: bool, use_mla: bool, + has_sink: bool) -> str: if (selected_backend != _Backend.PALLAS and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend 
on TPU.", selected_backend) diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index d8a663f2f0c4a..abd58dbbcbf45 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -35,8 +35,8 @@ class XPUPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: + block_size: int, use_v1: bool, use_mla: bool, + has_sink: bool) -> str: if selected_backend is not None and selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) use_v1 = envs.VLLM_USE_V1