[1/N][Platform] Cleanup useless function (#26982)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan 2025-10-22 17:04:57 +08:00 committed by GitHub
parent ab3e80042e
commit f6027b2855
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 21 additions and 106 deletions

View File

@ -9,6 +9,7 @@ Note: these tests will only pass on L4 GPU.
import pytest
from tests.quantization.utils import is_quant_method_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
@ -69,8 +70,10 @@ def test_models(
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None):
pytest.skip(f"{kv_cache_dtype} is not supported on this platform.")
if not flash_attn_supports_fp8():
pytest.skip(
f"{kv_cache_dtype} is not supported on this GPU type with {backend} attention."
)
with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true")

View File

@ -356,10 +356,6 @@ def test_compressed_tensors_fp8(vllm_runner):
assert output
@pytest.mark.skipif(
not current_platform.is_kv_cache_dtype_supported("fp8", None),
reason="FP8 KV cache is not supported on this device.",
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)

View File

@ -23,7 +23,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.config import VllmConfig
else:
_Backend = None
@ -457,49 +457,6 @@ class CudaPlatformBase(Platform):
def device_count(cls) -> int:
return cuda_device_count_stateless()
@classmethod
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
    """Return whether ``kv_cache_dtype`` can be served on this CUDA device.

    Mirrors the attention-backend selection rules: when the
    ``VLLM_ATTENTION_BACKEND`` env override is unset, the same backend
    defaults applied at selection time are assumed here.

    Args:
        kv_cache_dtype: Requested KV cache dtype string
            (e.g. "fp8", "fp8_e5m2").
        model_config: Model configuration; may be ``None``, in which case
            the non-MLA path is taken.

    Returns:
        True if the (possibly defaulted) attention backend supports the
        requested KV cache dtype.
    """
    # Every fp8 variant ("fp8", "fp8_e4m3", "fp8_e5m2") starts with "fp8".
    fp8_attention = kv_cache_dtype.startswith("fp8")
    # None unless the user forced a backend via the env var.
    attention_backend = envs.VLLM_ATTENTION_BACKEND

    supported = False
    if model_config is not None and model_config.use_mla:
        # Default to CutlassMLA for blackwell,
        # FlashMLA otherwise
        if attention_backend is None:
            if cls.is_device_capability(100):
                attention_backend = "CUTLASS_MLA"
            else:
                attention_backend = "FLASHMLA"

        # Only FlashMLA, CUTLASS_MLA and FLASHINFER_MLA support fp8.
        if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]:
            supported = True
        else:
            # Non-fp8 caches work on any MLA backend.
            supported = not fp8_attention
    else:
        # Default to FlashAttention
        if attention_backend is None:
            attention_backend = "FLASH_ATTN"

        # All Blackwell backends support fp8
        if cls.is_device_capability(100):
            supported = True
        elif attention_backend == "FLASH_ATTN":
            if fp8_attention:
                # Imported lazily — presumably to avoid an import cycle;
                # confirm before hoisting to module level.
                from vllm.attention.utils.fa_utils import flash_attn_supports_fp8

                supported = flash_attn_supports_fp8()
            else:
                supported = True
        elif attention_backend == "FLASHINFER":
            supported = True
        elif attention_backend == "TRITON_ATTN":
            supported = cls.supports_fp8()
        # Any other forced backend leaves `supported` at False.
    return supported
@classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102

View File

@ -7,28 +7,23 @@ import platform
import random
import sys
from datetime import timedelta
from platform import uname
from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np
import torch
from torch.distributed import PrefixStore, ProcessGroup
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser
else:
_Backend = object
ModelConfig = object
VllmConfig = object
PoolingParams = object
SamplingParams = object
FlexibleArgumentParser = object
logger = init_logger(__name__)
@ -36,7 +31,7 @@ logger = init_logger(__name__)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower()
return "microsoft" in " ".join(platform.uname()).lower()
class PlatformEnum(enum.Enum):
@ -178,7 +173,8 @@ class Platform:
import vllm._moe_C # noqa: F401
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
# Import _Backend here to avoid circular import.
from vllm.attention.backends.registry import _Backend
return _Backend.TORCH_SDPA
@ -186,7 +182,7 @@ class Platform:
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: _Backend,
selected_backend: "_Backend",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
@ -317,7 +313,7 @@ class Platform:
pass
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Check and update the configuration for the current platform.
@ -498,9 +494,9 @@ class Platform:
@classmethod
def validate_request(
cls,
prompt: PromptType,
params: SamplingParams | PoolingParams,
processed_inputs: ProcessorInputs,
prompt: "PromptType",
params: "SamplingParams | PoolingParams",
processed_inputs: "ProcessorInputs",
) -> None:
"""Raises if this request is unsupported on this platform"""
@ -543,25 +539,16 @@ class Platform:
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
prefix_store: "PrefixStore",
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
) -> "ProcessGroup":
"""
Init platform-specific torch distributed process group.
"""
raise NotImplementedError
@classmethod
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: ModelConfig
) -> bool:
    """Tell whether this platform can serve KV caches of ``kv_cache_dtype``.

    The generic base implementation is deliberately conservative and
    claims no support at all; platform subclasses override it with
    real per-backend checks.
    """
    return False
@classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype):
"""

View File

@ -15,7 +15,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.config import VllmConfig
else:
_Backend = None
@ -474,12 +474,6 @@ class RocmPlatform(Platform):
def device_count(cls) -> int:
return cuda_device_count_stateless()
@classmethod
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
    """Report every ``kv_cache_dtype`` as supported on this platform.

    Unconditional: neither argument is inspected.
    """
    return True
@classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102

View File

@ -222,12 +222,6 @@ class TpuPlatform(Platform):
):
raise ValueError("Torch XLA does not support per-request seed.")
@classmethod
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
    """Accept any ``kv_cache_dtype`` on this platform.

    Always returns True; the arguments are not examined.
    """
    return True
@classmethod
@torch.compile(backend="openxla")
def insert_blocks_to_device(

View File

@ -86,22 +86,6 @@ class XPUPlatform(Platform):
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
@classmethod
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
    """Check whether ``kv_cache_dtype`` is supported on XPU.

    Only the Triton attention backend (when explicitly selected via the
    ``VLLM_ATTENTION_BACKEND`` env var) supports fp8 KV caches; any
    other — or unset — backend reports no support.
    """
    backend_forced_to_triton = (
        envs.is_set("VLLM_ATTENTION_BACKEND")
        and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN"
    )
    if not backend_forced_to_triton:
        return False
    return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
@classmethod
def set_device(cls, device: torch.device) -> None:
"""