[Bugfix] Voxtral on Blackwell GPUs (RTX 50 series) (#21077)

Signed-off-by: hax0r31337 <liulihaocaiqwq@gmail.com>
This commit is contained in:
hax0r31337 2025-07-19 00:40:18 +02:00 committed by GitHub
parent 0f199f197b
commit 5782581acf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -16,6 +16,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@ -23,6 +24,34 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
def check_xformers_availability():
global USE_XFORMERS_OPS
if USE_XFORMERS_OPS is not None:
return USE_XFORMERS_OPS
if current_platform.is_cuda() and current_platform.has_device_capability(
100):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS = False
else:
try:
from importlib.util import find_spec
find_spec("xformers.ops")
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
# the warning only needs to be shown once
if not USE_XFORMERS_OPS:
logger.warning("Xformers is not available, falling back.")
return USE_XFORMERS_OPS
class Attention(nn.Module):
"""Attention layer.
@ -314,6 +343,10 @@ class MultiHeadAttention(nn.Module):
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA
if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
self.attn_backend = _Backend.TORCH_SDPA
def forward(
self,
query: torch.Tensor,