From c17231e827991d5778e8ed258e7cdcb12c35b149 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Tue, 22 Jul 2025 08:35:14 +0200
Subject: [PATCH] Fix kv_cache_dtype handling for out-of-tree HPU plugin
 (#21302)

Signed-off-by: Konrad Zawora
Signed-off-by: Chendi.Xue
Co-authored-by: Chendi.Xue
---
 vllm/engine/arg_utils.py    | 18 ++----------------
 vllm/platforms/cuda.py      | 13 +++++++++++++
 vllm/platforms/interface.py |  7 +++++++
 vllm/platforms/rocm.py      |  4 ++++
 vllm/platforms/tpu.py       |  4 ++++
 5 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 28b1c1c363a76..1f74d22d07c1c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1352,22 +1352,8 @@ class EngineArgs:
 
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            fp8_attention = self.kv_cache_dtype.startswith("fp8")
-            will_use_fa = (
-                current_platform.is_cuda()
-                and not envs.is_set("VLLM_ATTENTION_BACKEND")
-            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
-            supported = False
-            if (current_platform.is_rocm()
-                    or (current_platform.is_cuda()
-                        and current_platform.is_device_capability(100))
-                    or current_platform.is_tpu()):
-                supported = True
-            elif fp8_attention and will_use_fa:
-                from vllm.attention.utils.fa_utils import (
-                    flash_attn_supports_fp8)
-                supported = flash_attn_supports_fp8()
-
+            supported = current_platform.is_kv_cache_dtype_supported(
+                self.kv_cache_dtype)
             if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 962e2b3aab601..fdf1f46e603b4 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -586,6 +586,19 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
             " not found. Assuming no NVLink available.")
         return False
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # Autodetect either NVML-enabled or non-NVML platform
 # based on whether NVML is available.
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 1cd5cb5e83db7..02cc392244bac 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -543,6 +543,13 @@ class Platform:
         """
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        """
+        Returns if the kv_cache_dtype is supported by the current platform.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 0bf9262776b18..b2e69f60343f6 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -454,3 +454,7 @@ class RocmPlatform(Platform):
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
+
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True
\ No newline at end of file
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index febc6ae4662bf..146801c9d7739 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -190,6 +190,10 @@ class TpuPlatform(Platform):
                 and params.sampling_type == SamplingType.RANDOM_SEED):
             raise ValueError("Torch XLA does not support per-request seed.")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform