From f6027b28553da470d9821a25a0addf9472c6834a Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 22 Oct 2025 17:04:57 +0800 Subject: [PATCH] [1/N][Platform] Cleanup useless function (#26982) Signed-off-by: wangxiyuan --- tests/models/quantization/test_fp8.py | 7 ++- tests/quantization/test_compressed_tensors.py | 4 -- vllm/platforms/cuda.py | 45 +------------------ vllm/platforms/interface.py | 41 ++++++----------- vllm/platforms/rocm.py | 8 +--- vllm/platforms/tpu.py | 6 --- vllm/platforms/xpu.py | 16 ------- 7 files changed, 21 insertions(+), 106 deletions(-) diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index 55b149ae5da71..bac613913e91f 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -9,6 +9,7 @@ Note: these tests will only pass on L4 GPU. import pytest from tests.quantization.utils import is_quant_method_supported +from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 from vllm.platforms import current_platform from vllm.utils import STR_BACKEND_ENV_VAR @@ -69,8 +70,10 @@ def test_models( if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") - if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None): - pytest.skip(f"{kv_cache_dtype} is not supported on this platform.") + if not flash_attn_supports_fp8(): + pytest.skip( + f"{kv_cache_dtype} is not supported on this GPU type with {backend} attention." + ) with monkeypatch.context() as m: m.setenv("TOKENIZERS_PARALLELISM", "true") diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 1040cf70eb81e..e7d902ed26aaa 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -356,10 +356,6 @@ def test_compressed_tensors_fp8(vllm_runner): assert output -@pytest.mark.skipif( - not current_platform.is_kv_cache_dtype_supported("fp8", None), - reason="FP8 KV cache is not supported on this device.", -) @pytest.mark.skipif( not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." ) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c736e084a38df..03e6b3c295c7c 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -23,7 +23,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.config import VllmConfig else: _Backend = None @@ -457,49 +457,6 @@ class CudaPlatformBase(Platform): def device_count(cls) -> int: return cuda_device_count_stateless() - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - fp8_attention = kv_cache_dtype.startswith("fp8") - attention_backend = envs.VLLM_ATTENTION_BACKEND - - supported = False - if model_config is not None and model_config.use_mla: - # Default to CutlassMLA for blackwell, - # FlashMLA otherwise - if attention_backend is None: - if cls.is_device_capability(100): - attention_backend = "CUTLASS_MLA" - else: - attention_backend = "FLASHMLA" - - # Only FlashMLA and CUTLASS_MLA support fp8 - if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]: - supported = True - else: - supported = not fp8_attention - else: - # Default to FlashAttention - if attention_backend is None: - attention_backend = "FLASH_ATTN" - - # All Blackwell backends support fp8 - if cls.is_device_capability(100): - supported = True - elif attention_backend == "FLASH_ATTN": - if fp8_attention: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - - supported = flash_attn_supports_fp8() - else: - supported = True - elif attention_backend == "FLASHINFER": - supported = True - elif attention_backend == "TRITON_ATTN": - supported = cls.supports_fp8() - return supported - @classmethod def check_if_supports_dtype(cls, dtype: torch.dtype): if dtype == torch.bfloat16: # noqa: SIM102 diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f9f2cc4d34e2d..098e9058f5292 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,28 +7,23 @@ import platform import random import sys from datetime import timedelta -from platform import uname from typing import TYPE_CHECKING, Any, NamedTuple import numpy as np import torch -from torch.distributed import PrefixStore, ProcessGroup -from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger if TYPE_CHECKING: + from torch.distributed import PrefixStore, ProcessGroup + from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.config import VllmConfig + from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import FlexibleArgumentParser else: - _Backend = object - ModelConfig = object - VllmConfig = object - PoolingParams = object - SamplingParams = object FlexibleArgumentParser = object logger = init_logger(__name__) @@ -36,7 +31,7 @@ logger = init_logger(__name__) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 - return "microsoft" in " ".join(uname()).lower() + return "microsoft" in " ".join(platform.uname()).lower() class PlatformEnum(enum.Enum): @@ -178,7 +173,8 @@ class Platform: import vllm._moe_C # noqa: F401 @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + # Import _Backend here to avoid circular import. from vllm.attention.backends.registry import _Backend return _Backend.TORCH_SDPA @@ -186,7 +182,7 @@ class Platform: @classmethod def get_attn_backend_cls( cls, - selected_backend: _Backend, + selected_backend: "_Backend", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -317,7 +313,7 @@ class Platform: pass @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ Check and update the configuration for the current platform. @@ -498,9 +494,9 @@ class Platform: @classmethod def validate_request( cls, - prompt: PromptType, - params: SamplingParams | PoolingParams, - processed_inputs: ProcessorInputs, + prompt: "PromptType", + params: "SamplingParams | PoolingParams", + processed_inputs: "ProcessorInputs", ) -> None: """Raises if this request is unsupported on this platform""" @@ -543,25 +539,16 @@ class Platform: def stateless_init_device_torch_dist_pg( cls, backend: str, - prefix_store: PrefixStore, + prefix_store: "PrefixStore", group_rank: int, group_size: int, timeout: timedelta, - ) -> ProcessGroup: + ) -> "ProcessGroup": """ Init platform-specific torch distributed process group. """ raise NotImplementedError - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: ModelConfig - ) -> bool: - """ - Returns if the kv_cache_dtype is supported by the current platform. - """ - return False - @classmethod def check_if_supports_dtype(cls, dtype: torch.dtype): """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 9788bfeca109c..7aab0b76aa063 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -15,7 +15,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.config import VllmConfig else: _Backend = None @@ -474,12 +474,6 @@ class RocmPlatform(Platform): def device_count(cls) -> int: return cuda_device_count_stateless() - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - return True - @classmethod def check_if_supports_dtype(cls, dtype: torch.dtype): if dtype == torch.bfloat16: # noqa: SIM102 diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index ed38f3bc30878..ab752f438f727 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -222,12 +222,6 @@ class TpuPlatform(Platform): ): raise ValueError("Torch XLA does not support per-request seed.") - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - return True - @classmethod @torch.compile(backend="openxla") def insert_blocks_to_device( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5799f97b8038d..db7c3549df519 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -86,22 +86,6 @@ class XPUPlatform(Platform): logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - """ - Check if the kv_cache_dtype is supported. - XPU only support fp8 kv cache with triton backend. - """ - if ( - envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN" - ): - return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] - - return False - @classmethod def set_device(cls, device: torch.device) -> None: """