From 7337ec6c9f79ed5345feb881763317f9906ae964 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 22 Sep 2025 14:27:51 -0400
Subject: [PATCH] [CI Failure] Fix fp8 kv cache on

Signed-off-by: yewentao256
---
 vllm/platforms/cuda.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 7baa5a9742f44..b10bc03ee16c6 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -286,6 +286,9 @@ class CudaPlatformBase(Platform):
         TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
         XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
 
+        use_fp8_kv_cache = (kv_cache_dtype is not None
+                            and kv_cache_dtype.startswith("fp8"))
+
         if selected_backend == _Backend.FLASHINFER:
             logger.info_once("Using FlashInfer backend on V1 engine.")
             if cls.has_device_capability(100):
@@ -334,10 +337,11 @@ class CudaPlatformBase(Platform):
 
         # FlashAttention is the default for SM 8.0+ GPUs
         if cls.has_device_capability(80):
-            if has_sink and not cls.is_device_capability(90):
+            if (has_sink or
+                    use_fp8_kv_cache) and not cls.is_device_capability(90):
                 logger.info_once("Using Triton backend on V1 engine.")
                 return TRITON_ATTN_VLLM_V1
-            if is_default_backend_supported := is_attn_backend_supported(
+            elif is_default_backend_supported := is_attn_backend_supported(
                     FLASH_ATTN_V1, head_size, dtype, allow_import_error=False):
                 logger.info_once("Using Flash Attention backend on "
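
For reference, below is a minimal standalone sketch of the selection rule this patch introduces, not part of the patch itself. The function name pick_backend and the integer capability parameter are hypothetical stand-ins for illustration; the real code uses cls.has_device_capability / cls.is_device_capability and falls through to other backend checks not shown here.

def pick_backend(kv_cache_dtype, has_sink, capability):
    """Approximate the patched backend choice for SM 8.0+ GPUs."""
    use_fp8_kv_cache = (kv_cache_dtype is not None
                        and kv_cache_dtype.startswith("fp8"))
    if capability < 80:
        return "other"  # outside the scope of this hunk
    # Attention sinks and fp8 kv cache route to Triton everywhere except SM 9.0.
    if (has_sink or use_fp8_kv_cache) and capability != 90:
        return "TRITON_ATTN_VLLM_V1"
    return "FLASH_ATTN_V1"

# fp8 kv cache on an SM 8.9 (Ada) GPU now routes to the Triton backend.
assert pick_backend("fp8_e4m3", has_sink=False, capability=89) == "TRITON_ATTN_VLLM_V1"
# On SM 9.0 (Hopper), FlashAttention remains the default even with fp8 kv cache.
assert pick_backend("fp8_e4m3", has_sink=False, capability=90) == "FLASH_ATTN_V1"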