From b0d541947ab5b8c361077ee85df7b3a2138472f9 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Sat, 8 Mar 2025 21:18:39 -0500
Subject: [PATCH] [Attention] Default to FlashMLA backend for MLA (#14451)

Signed-off-by: Lucas Wilkinson
Signed-off-by: Tyler Michael Smith
Co-authored-by: Tyler Michael Smith
---
 vllm/platforms/cuda.py | 40 ++++++++++++++++++++++++----------------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 4be93148139d7..1bba99088bb29 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -112,6 +112,7 @@ class CudaPlatformBase(Platform):
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         compilation_config = vllm_config.compilation_config
+        model_config = vllm_config.model_config
 
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
@@ -142,14 +143,21 @@ class CudaPlatformBase(Platform):
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
+
         # TODO(lucas): handle this more gracefully
-        if envs.VLLM_ATTENTION_BACKEND is not None \
-            and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \
-            and cache_config.block_size != 64:
-            cache_config.block_size = 64
-            logger.info(
-                "FlashMLA: Forcing kv cache block size to 64 since this"
-                " is currently the only block size supported by the kernel.")
+        # Note: model_config may be None during testing
+        if model_config is not None and model_config.use_mla:
+            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
+            # we default to FlashMLA backend, so we need to force the blocksize
+            # here
+            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
+                or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            from vllm.attention.backends.flashmla import is_flashmla_supported
+            if use_flashmla and is_flashmla_supported()[0] \
+                and cache_config.block_size != 64:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashMLA backend.")
 
         if (parallel_config.data_parallel_size > 1
                 and compilation_config.use_cudagraph):
@@ -173,7 +181,15 @@ class CudaPlatformBase(Platform):
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
-            if selected_backend == _Backend.FLASHMLA:
+            if selected_backend == _Backend.TRITON_MLA or block_size != 64:
+                if use_v1:
+                    logger.info_once("Using Triton MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "triton_mla.TritonMLABackend")
+                else:
+                    logger.info("Using Triton MLA backend.")
+                    return "vllm.attention.backends.triton_mla.TritonMLABackend"
+            else:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
                 if not is_flashmla_supported()[0]:
@@ -195,14 +211,6 @@ class CudaPlatformBase(Platform):
                 logger.info("Using FlashMLA backend.")
                 return ("vllm.attention.backends."
                         "flashmla.FlashMLABackend")
-
-        if use_v1:
-            logger.info_once("Using Triton MLA backend on V1 engine.")
-            return ("vllm.v1.attention.backends.mla."
-                    "triton_mla.TritonMLABackend")
-        else:
-            logger.info("Using Triton MLA backend.")
-            return "vllm.attention.backends.triton_mla.TritonMLABackend"
         if use_v1:
             logger.info_once("Using Flash Attention backend on V1 engine.")
             return ("vllm.v1.attention.backends.flash_attn."