diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py
index f23c096952ce..411eb5413f53 100644
--- a/vllm/attention/backends/flashmla.py
+++ b/vllm/attention/backends/flashmla.py
@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.platforms.cuda import CudaPlatform
 
 
 class FlashMLABackend(MLACommonBackend):
@@ -181,6 +182,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert is_flashmla_supported(), \
             "FlashMLA is not supported on this device"
 
+        # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
+        # context:
+        # https://github.com/deepseek-ai/FlashMLA/issues/83
+        # https://github.com/vllm-project/vllm/issues/24513
+        if CudaPlatform.has_device_capability(100):
+            raise NotImplementedError(
+                "FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
+                "Please use CUTLASS_MLA or TRITON_MLA instead. "
+                "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")
+
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
             raise NotImplementedError(
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 2f13f19218d9..549af1a06225 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          is_flashmla_supported)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms.cuda import CudaPlatform
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
                                                    MLACommonImpl,
@@ -158,6 +159,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert is_flashmla_supported(), \
             "FlashMLA is not supported on this device"
 
+        # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
+        # context:
+        # https://github.com/deepseek-ai/FlashMLA/issues/83
+        # https://github.com/vllm-project/vllm/issues/24513
+        if CudaPlatform.has_device_capability(100):
+            raise NotImplementedError(
+                "FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
+                "Please use CUTLASS_MLA or TRITON_MLA instead. "
+                "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")
+
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
             raise NotImplementedError(
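
The error message added in both hunks points users at the `VLLM_ATTENTION_BACKEND` environment variable as the workaround on Blackwell. Below is a minimal, illustrative sketch (not part of the patch) of applying that workaround from Python before constructing the engine; the model name is a placeholder, and which MLA backends are available depends on the build.

```python
import os

# Illustrative workaround only: select a non-FlashMLA MLA backend up front
# so the new Blackwell (SM 10.0) guard in FlashMLAImpl is never reached.
# "CUTLASS_MLA" and "TRITON_MLA" are the alternatives named in the error
# message; availability depends on your vLLM build and GPU.
os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA"

from vllm import LLM  # set the env var before engine construction

# Placeholder MLA model for illustration.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")
```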