diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 52b33849dae6..e40b6eb2b5a4 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -146,6 +146,7 @@ class CudaPlatformBase(Platform):
             # required block_size.
             use_flashmla = False
             use_cutlass_mla = False
+            use_flashinfer_mla = False
 
             if envs.VLLM_ATTENTION_BACKEND is None:
                 # Default case
@@ -164,6 +165,8 @@ class CudaPlatformBase(Platform):
                 use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                 use_cutlass_mla = (
                     envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
+                use_flashinfer_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")
 
             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
@@ -176,6 +179,11 @@ class CudaPlatformBase(Platform):
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
                             "CUTLASS_MLA backend.")
+            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashInferMLA "
+                    "backend.")
 
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
@@ -228,8 +236,9 @@ class CudaPlatformBase(Platform):
             use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                 selected_backend is None and cls.is_device_capability(100)
                 and block_size == 128)
-            use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
-                                 and cls.has_device_capability(100))
+            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
+                selected_backend is None and cls.is_device_capability(100)
+                and block_size in [32, 64])
             use_flashmla = selected_backend in [
                 _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
             ] or (selected_backend is None and is_flashmla_supported()[0])
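
A minimal standalone sketch of the block-size coercion added in the third hunk, assuming only that the FlashInfer MLA path accepts KV cache block sizes of 32 or 64 and otherwise falls back to 64 (as the patch logs). The names `BlockSizeConfig` and `coerce_block_size_for_flashinfer_mla` are illustrative and are not part of the vLLM API.

    # Illustrative sketch, not vLLM code: mirrors the new branch that forces
    # an unsupported KV cache block size to 64 when FlashInfer MLA is in use.
    from dataclasses import dataclass


    @dataclass
    class BlockSizeConfig:
        block_size: int


    def coerce_block_size_for_flashinfer_mla(use_flashinfer_mla: bool,
                                             cache_config: BlockSizeConfig) -> None:
        # Only block sizes 32 and 64 are left untouched; anything else
        # becomes 64, matching the added `if use_flashinfer_mla and ...` branch.
        if use_flashinfer_mla and cache_config.block_size not in (32, 64):
            cache_config.block_size = 64


    cfg = BlockSizeConfig(block_size=16)
    coerce_block_size_for_flashinfer_mla(True, cfg)
    assert cfg.block_size == 64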