diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 9a8941e3cdd19..c35d22c1d6824 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -150,11 +150,28 @@ class CudaPlatformBase(Platform):
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
         if model_config is not None and model_config.use_mla:
-            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
-            # we default to FlashMLA backend, so we need to force the blocksize
-            # here
-            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
-                or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
+            # then we default to FlashMLA backend for non-blackwell GPUs,
+            # else we default to CutlassMLA. For each case, we force the
+            # required block_size.
+            use_flashmla = False
+            use_cutlass_mla = False
+
+            if envs.VLLM_ATTENTION_BACKEND is None:
+                # Default case
+                if cls.is_device_capability(100):
+                    # Blackwell => Force CutlassMLA.
+                    use_cutlass_mla = True
+                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA_VLLM_V1"
+                else:
+                    # Not Blackwell
+                    use_flashmla = True
+            else:
+                # Forced case
+                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+                use_cutlass_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
+
             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
                     and cache_config.block_size != 64:
@@ -162,8 +179,6 @@ class CudaPlatformBase(Platform):
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLA backend.")
 
-            use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND is not None \
-                and envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
             if use_cutlass_mla and cache_config.block_size != 128:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
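
For reference, here is a minimal standalone sketch of the selection logic this diff adds. It is not vLLM code: the function name pick_mla_backend_and_block_size and its parameters are hypothetical stand-ins for the vLLM objects involved (envs.VLLM_ATTENTION_BACKEND, cls.is_device_capability(100), is_flashmla_supported()[0], and cache_config.block_size), so the control flow can be read and run in isolation.

# Illustrative sketch only; vLLM-specific state is passed in as plain values.
def pick_mla_backend_and_block_size(
    attention_backend_env,  # stand-in for envs.VLLM_ATTENTION_BACKEND (str or None)
    is_blackwell,           # stand-in for cls.is_device_capability(100)
    flashmla_supported,     # stand-in for is_flashmla_supported()[0]
    block_size,             # stand-in for cache_config.block_size
):
    """Return the (possibly updated) backend name and kv-cache block size."""
    use_flashmla = False
    use_cutlass_mla = False

    if attention_backend_env is None:
        # Default case: Blackwell (device capability 100) gets CutlassMLA,
        # every other GPU gets FlashMLA.
        if is_blackwell:
            use_cutlass_mla = True
            attention_backend_env = "CUTLASS_MLA_VLLM_V1"
        else:
            use_flashmla = True
    else:
        # Forced case: respect the user's explicit backend choice.
        use_flashmla = attention_backend_env == "FLASHMLA"
        use_cutlass_mla = attention_backend_env == "CUTLASS_MLA_VLLM_V1"

    # Each MLA backend requires a specific kv-cache block size.
    if use_flashmla and flashmla_supported and block_size != 64:
        block_size = 64
    if use_cutlass_mla and block_size != 128:
        block_size = 128

    return attention_backend_env, block_size


if __name__ == "__main__":
    # Unset backend on a Blackwell GPU => CutlassMLA with block_size 128.
    print(pick_mla_backend_and_block_size(None, True, True, 16))
    # Unset backend on a non-Blackwell GPU => FlashMLA with block_size 64.
    print(pick_mla_backend_and_block_size(None, False, True, 16))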