[Feature] Disallow FlashMLA on Blackwell (#24521)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Wentao Ye 2025-09-09 14:59:34 -04:00 committed by GitHub
parent b8a93076d3
commit 15de5ff9ea
2 changed files with 22 additions and 0 deletions


@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.platforms.cuda import CudaPlatform


 class FlashMLABackend(MLACommonBackend):
@@ -181,6 +182,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert is_flashmla_supported(), \
             "FlashMLA is not supported on this device"

+        # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
+        # context:
+        # https://github.com/deepseek-ai/FlashMLA/issues/83
+        # https://github.com/vllm-project/vllm/issues/24513
+        if CudaPlatform.has_device_capability(100):
+            raise NotImplementedError(
+                "FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
+                "Please use CUTLASS_MLA or TRITON_MLA instead. "
+                "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")
+
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
             raise NotImplementedError(
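
For context on the gate itself: CudaPlatform.has_device_capability(100) compares against the compute capability encoded as major * 10 + minor, so it is true exactly for SM 10.0 (Blackwell) and newer. A minimal standalone sketch of the equivalent check in plain PyTorch; the helper name is illustrative and not part of this patch:

import torch

def is_sm100_or_newer() -> bool:
    # Illustrative stand-in for CudaPlatform.has_device_capability(100).
    # get_device_capability() returns (major, minor), e.g. (10, 0) on Blackwell.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= 100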


@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          is_flashmla_supported)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms.cuda import CudaPlatform
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
                                                    MLACommonImpl,
@@ -158,6 +159,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         assert is_flashmla_supported(), \
             "FlashMLA is not supported on this device"

+        # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
+        # context:
+        # https://github.com/deepseek-ai/FlashMLA/issues/83
+        # https://github.com/vllm-project/vllm/issues/24513
+        if CudaPlatform.has_device_capability(100):
+            raise NotImplementedError(
+                "FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
+                "Please use CUTLASS_MLA or TRITON_MLA instead. "
+                "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")
+
         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
             raise NotImplementedError(
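
As the new error message suggests, Blackwell users can force one of the other MLA backends until FlashMLA support lands. A hedged usage sketch of that workaround; the model name is a placeholder assumption (any MLA model, e.g. a DeepSeek checkpoint, applies):

import os

# Select an alternative MLA backend before vLLM picks one,
# per the recommendation in the new error message.
os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA"  # or "TRITON_MLA"

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")  # placeholder MLA model
print(llm.generate("Hello")[0].outputs[0].text)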