Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-20 10:55:50 +08:00)
Fix kv_cache_dtype handling for out-of-tree HPU plugin (#21302)

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>

parent 6e5b5ca580
commit c17231e827
vllm/engine/arg_utils.py
@@ -1352,22 +1352,8 @@ class EngineArgs:
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            fp8_attention = self.kv_cache_dtype.startswith("fp8")
-            will_use_fa = (
-                current_platform.is_cuda()
-                and not envs.is_set("VLLM_ATTENTION_BACKEND")
-            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
-            supported = False
-            if (current_platform.is_rocm()
-                    or (current_platform.is_cuda()
-                        and current_platform.is_device_capability(100))
-                    or current_platform.is_tpu()):
-                supported = True
-            elif fp8_attention and will_use_fa:
-                from vllm.attention.utils.fa_utils import (
-                    flash_attn_supports_fp8)
-                supported = flash_attn_supports_fp8()
-
+            supported = current_platform.is_kv_cache_dtype_supported(
+                self.kv_cache_dtype)
             if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
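The engine-side check above no longer enumerates CUDA, ROCm, and TPU by hand; it delegates to whichever platform object is active, which is what lets an out-of-tree platform such as the HPU plugin declare its own support. A minimal sketch of the resulting flow, assuming the public vllm.platforms.current_platform entry point (illustration only, with a hypothetical helper name, not the exact vLLM code):

# Sketch only: the engine asks the active platform instead of hard-coding vendors.
from vllm.platforms import current_platform

def kv_cache_dtype_ok(kv_cache_dtype: str) -> bool:  # hypothetical helper
    # "auto" always passes; explicit dtypes are validated by the platform hook.
    if kv_cache_dtype == "auto":
        return True
    return current_platform.is_kv_cache_dtype_supported(kv_cache_dtype)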
vllm/platforms/cuda.py
@@ -586,6 +586,19 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
                     " not found. Assuming no NVLink available.")
         return False
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # Autodetect either NVML-enabled or non-NVML platform
 # based on whether NVML is available.
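On CUDA the hook mirrors the logic removed from EngineArgs: an FP8 KV cache is accepted either on devices reporting capability 100 (10.0 in vLLM's major*10+minor encoding, i.e. Blackwell-class GPUs) or when the FlashAttention backend that will be selected supports FP8. A quick way to probe it on a CUDA machine, using the dtype strings vLLM accepts for --kv-cache-dtype (illustration only, not part of the commit):

# Illustration only: query the new hook directly on the active platform.
from vllm.platforms import current_platform

for dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
    # "auto" never reaches the hook; the engine short-circuits it earlier.
    print(dtype, current_platform.is_kv_cache_dtype_supported(dtype))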
vllm/platforms/interface.py
@@ -543,6 +543,13 @@ class Platform:
         """
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        """
+        Returns if the kv_cache_dtype is supported by the current platform.
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
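Because the base-class default returns False, platforms that do not override the hook keep the previous conservative behaviour for non-"auto" dtypes. An out-of-tree plugin, such as the HPU backend this commit is fixing, can now opt in by overriding the method. A hypothetical sketch with illustrative names, assuming out-of-tree platforms register under PlatformEnum.OOT (this is not the actual plugin code):

# Hypothetical out-of-tree platform sketch; only the new hook is shown.
from vllm.platforms.interface import Platform, PlatformEnum

class MyOotPlatform(Platform):
    _enum = PlatformEnum.OOT  # assumed: the enum value used for external plugins

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        # Advertise FP8 KV-cache support; "auto" is handled before this hook.
        return kv_cache_dtype.startswith("fp8")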
vllm/platforms/rocm.py
@@ -454,3 +454,7 @@ class RocmPlatform(Platform):
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
+
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True
vllm/platforms/tpu.py
@@ -190,6 +190,10 @@ class TpuPlatform(Platform):
                 and params.sampling_type == SamplingType.RANDOM_SEED):
             raise ValueError("Torch XLA does not support per-request seed.")
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        return True
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
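ROCm and TPU simply return True, matching the branches they previously occupied in the EngineArgs check. Nothing changes on the user side: the dtype is requested the same way and is now validated through the platform hook. A small usage sketch with the standard vLLM API (the model name is just an example; FP8 KV cache must actually be supported on the running hardware):

# Standard vLLM usage; kv_cache_dtype is validated via the new platform hook.
from vllm import LLM

llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")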