[Bugfix][CUDA] Fix CUDA FP8 KV cache dtype support (#21420)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
parent: 08d2bd78da
commit: 2dec7c1a5d
@@ -456,6 +456,19 @@ class CudaPlatformBase(Platform):
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
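For reference, the decision logic of the new classmethod can be sketched standalone. This is a minimal sketch, not vLLM source: is_device_capability_100 and flash_attn_supports_fp8_stub are hypothetical stand-ins for cls.is_device_capability(100) and the fa_utils helper, and reading os.environ only approximates envs.is_set / envs.VLLM_ATTENTION_BACKEND.

import os


def is_device_capability_100() -> bool:
    # Stub for cls.is_device_capability(100): True only on
    # compute-capability-10.0 GPUs.
    return False


def flash_attn_supports_fp8_stub() -> bool:
    # Stub for fa_utils.flash_attn_supports_fp8: whether the installed
    # FlashAttention build handles FP8 KV caches.
    return True


def is_kv_cache_dtype_supported(kv_cache_dtype: str) -> bool:
    fp8_attention = kv_cache_dtype.startswith("fp8")
    # Mirrors the envs check in the diff: FlashAttention is the backend
    # when none is forced, or when it is explicitly selected.
    backend = os.environ.get("VLLM_ATTENTION_BACKEND")
    will_use_fa = backend is None or backend == "FLASH_ATTN_VLLM_V1"

    if is_device_capability_100():
        return True  # capability-10.0 devices: supported unconditionally
    if fp8_attention and will_use_fa:
        return flash_attn_supports_fp8_stub()
    return False


if __name__ == "__main__":
    print(is_kv_cache_dtype_supported("fp8_e4m3"))  # True with these stubs
    print(is_kv_cache_dtype_supported("fp16"))      # False: not an fp8 dtype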
@@ -583,19 +596,6 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
                     " not found. Assuming no NVLink available.")
         return False
 
-    @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
-        fp8_attention = kv_cache_dtype.startswith("fp8")
-        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
-                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
-        supported = False
-        if cls.is_device_capability(100):
-            supported = True
-        elif fp8_attention and will_use_fa:
-            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
-            supported = flash_attn_supports_fp8()
-        return supported
-
 
 # Autodetect either NVML-enabled or non-NVML platform
 # based on whether NVML is available.
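The deletion is safe because NonNvmlCudaPlatform subclasses CudaPlatformBase (per the hunk headers above), so the method is now inherited instead of duplicated. A minimal sketch of the pattern, with the body elided:

class CudaPlatformBase:
    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        ...  # the shared check added in the first hunk


class NonNvmlCudaPlatform(CudaPlatformBase):
    # The duplicate classmethod deleted in this hunk is no longer needed:
    # lookups now resolve to the base-class implementation.
    pass


# Both classes share the very same underlying function object.
assert (NonNvmlCudaPlatform.is_kv_cache_dtype_supported.__func__
        is CudaPlatformBase.is_kv_cache_dtype_supported.__func__)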