From ab5b6459df40bda3157300c890fd1057ad8f96a9 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 28 Sep 2025 23:03:51 -0700 Subject: [PATCH] [Bugfix] Fallback ViT attn backend to SDPA for blackwell (#25851) Signed-off-by: Roger Wang Signed-off-by: simon-mo --- vllm/model_executor/models/qwen3_vl.py | 10 +--------- vllm/platforms/cuda.py | 6 ++++++ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 8041448ed09a..80381270f419 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -66,7 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend, current_platform +from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.utils import is_list_of @@ -335,14 +335,6 @@ class Qwen3_VisionTransformer(nn.Module): }: raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now.") - if current_platform.is_device_capability( - 100) and self.attn_backend != _Backend.TORCH_SDPA: - # TODO(Roger/Wentao): remove this after FA - # or XFORMERS's issue fixed on Blackwell - logger.info_once("Qwen3-VL vision attention does not support " - f"{self.attn_backend} backend on Blackwell now. " - "Vision attention backend is set to TORCH_SDPA.") - self.attn_backend = _Backend.TORCH_SDPA self.blocks = nn.ModuleList([ Qwen3_VisionBlock( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 58ba08101bc9..8b9f9f569206 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -205,6 +205,12 @@ class CudaPlatformBase(Platform): @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: + + # For Blackwell GPUs, force TORCH_SDPA for now. + # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 + if cls.has_device_capability(100): + return _Backend.TORCH_SDPA + if dtype not in (torch.float16, torch.bfloat16): return _Backend.XFORMERS