diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index cc2543538d0d..9a8941e3cdd1 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -456,6 +456,19 @@ class CudaPlatformBase(Platform):
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -583,19 +596,6 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
             " not found. Assuming no NVLink available.")
         return False
 
-    @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
-        fp8_attention = kv_cache_dtype.startswith("fp8")
-        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
-                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
-        supported = False
-        if cls.is_device_capability(100):
-            supported = True
-        elif fp8_attention and will_use_fa:
-            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
-            supported = flash_attn_supports_fp8()
-        return supported
-
 
 # Autodetect either NVML-enabled or non-NVML platform
 # based on whether NVML is available.
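
After this patch the check lives on CudaPlatformBase, so both the NVML-enabled
and non-NVML CUDA platforms inherit one implementation instead of only
NonNvmlCudaPlatform defining it. A minimal usage sketch, assuming a CUDA build
of vLLM with at least one visible GPU (the dtype strings are illustrative, not
an exhaustive list):

    # Sketch only: the relocated classmethod is reachable through whichever
    # CUDA platform vLLM autodetects (NVML-enabled or not), since both
    # subclasses now inherit it from CudaPlatformBase.
    from vllm.platforms import current_platform

    for dtype in ("auto", "fp8"):  # "fp8" can reach the flash-attn capability check
        print(dtype, current_platform.is_kv_cache_dtype_supported(dtype))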