[HybridKVCache][Platform] Add support_hybrid_kv_cache for platform (#24646)

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao 2025-09-11 21:47:58 +08:00 committed by GitHub
parent 94e6b2d55f
commit 4f6593b058
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 20 additions and 2 deletions

View File

@ -3529,8 +3529,7 @@ class VllmConfig:
# logger should only print warning message for hybrid models. As we # logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log # can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later. # warning message here and will log it later.
if not (current_platform.is_cuda() or current_platform.is_rocm() if not current_platform.support_hybrid_kv_cache():
or current_platform.is_cpu()):
# Hybrid KV cache manager is not supported on the current platform. # Hybrid KV cache manager is not supported on the current platform.
self.scheduler_config.disable_hybrid_kv_cache_manager = True self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None: if self.kv_transfer_config is not None:

View File

@ -347,3 +347,7 @@ class CpuPlatform(Platform):
@classmethod @classmethod
def opaque_attention_op(cls) -> bool: def opaque_attention_op(cls) -> bool:
return True return True
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Return whether this platform supports the hybrid KV cache.

    The CPU platform reports support unconditionally.

    Returns:
        Always ``True``.
    """
    return True

View File

@ -571,6 +571,10 @@ class CudaPlatformBase(Platform):
"You can use float16 instead by explicitly setting the " "You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.") "`dtype` flag in CLI, for example: --dtype=half.")
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Return whether this platform supports the hybrid KV cache.

    The CUDA platform base reports support unconditionally.

    Returns:
        Always ``True``.
    """
    return True
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@ -586,6 +586,13 @@ class Platform:
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Return whether the hybrid KV cache is supported by this platform.

    This base implementation is the conservative default: a platform is
    treated as not supporting the hybrid KV cache unless it overrides
    this method to return ``True``.

    Returns:
        ``False`` for the base platform.
    """
    return False
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED

View File

@ -498,3 +498,7 @@ class RocmPlatform(Platform):
f"Your {gpu_name} GPU {compute_str}. " f"Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the " "You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.") "`dtype` flag in CLI, for example: --dtype=half.")
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Return whether this platform supports the hybrid KV cache.

    The ROCm platform reports support unconditionally.

    Returns:
        Always ``True``.
    """
    return True