From b42566f44039ed8dc5f6c124203d5d71740c6fe0 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Mon, 15 Sep 2025 22:10:55 -0400
Subject: [PATCH] [Bug] Fix `is_flashmla_supported` Check Error (#24774)

Signed-off-by: yewentao256
---
 vllm/attention/backends/flashmla.py        | 15 ++-------------
 vllm/v1/attention/backends/mla/flashmla.py | 15 ++-------------
 2 files changed, 4 insertions(+), 26 deletions(-)

diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py
index 411eb5413f53c..aeaa0ab631cfb 100644
--- a/vllm/attention/backends/flashmla.py
+++ b/vllm/attention/backends/flashmla.py
@@ -17,7 +17,6 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
-from vllm.platforms.cuda import CudaPlatform
 
 
 class FlashMLABackend(MLACommonBackend):
@@ -179,18 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                          logits_soft_cap, attn_type,
                          kv_sharing_target_layer_name, **mla_args)
 
-        assert is_flashmla_supported(), \
-            "FlashMLA is not supported on this device"
-
-        # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
-        # context:
-        #   https://github.com/deepseek-ai/FlashMLA/issues/83
-        #   https://github.com/vllm-project/vllm/issues/24513
-        if CudaPlatform.has_device_capability(100):
-            raise NotImplementedError(
-                "FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
-                "Please use CUTLASS_MLA or TRITON_MLA instead. "
-                "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")
+        is_supported, reason = is_flashmla_supported()
+        assert is_supported, reason
 
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 549af1a062252..150e38553e4bb 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -12,7 +12,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          is_flashmla_supported)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.platforms.cuda import CudaPlatform
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
                                                    MLACommonImpl,
@@ -156,18 +155,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
                          logits_soft_cap, attn_type,
                          kv_sharing_target_layer_name, **mla_args)
 
-        assert is_flashmla_supported(), \
-            "FlashMLA is not supported on this device"
-
-        # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
-        # context:
-        #   https://github.com/deepseek-ai/FlashMLA/issues/83
-        #   https://github.com/vllm-project/vllm/issues/24513
-        if CudaPlatform.has_device_capability(100):
-            raise NotImplementedError(
-                "FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
-                "Please use CUTLASS_MLA or TRITON_MLA instead. "
-                "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")
+        is_supported, reason = is_flashmla_supported()
+        assert is_supported, reason
 
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
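
Review note (not part of the patch): the unpacking introduced above implies that
`is_flashmla_supported()` returns a `(bool, str)` tuple rather than a bare bool.
In Python a non-empty tuple is always truthy, so the old
`assert is_flashmla_supported(), "..."` could never fail, even on devices where
FlashMLA is unsupported. A minimal self-contained sketch of the two checks,
using a hypothetical stub in place of the real helper:

    def is_flashmla_supported():
        # Hypothetical stub standing in for the real helper on a device where
        # FlashMLA is unavailable; assumes a (bool, str) return, as the
        # unpacking in the patch implies.
        return False, "FlashMLA is not supported on this device"

    # Old check: the assert tests the tuple itself. A 2-tuple is truthy
    # regardless of its contents, so this passes silently.
    assert is_flashmla_supported(), "FlashMLA is not supported on this device"

    # New check: unpack the tuple and assert on the boolean, so the reason
    # string surfaces when support is missing.
    try:
        is_supported, reason = is_flashmla_supported()
        assert is_supported, reason
    except AssertionError as exc:
        print(f"FlashMLA unavailable: {exc}")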