diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index f9c2d4f49835d..b6b93ff4a0ac0 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -16,6 +16,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -23,6 +24,34 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
 
+logger = init_logger(__name__)
+USE_XFORMERS_OPS = None
+
+
+def check_xformers_availability():
+    global USE_XFORMERS_OPS
+    if USE_XFORMERS_OPS is not None:
+        return USE_XFORMERS_OPS
+
+    if current_platform.is_cuda() and current_platform.has_device_capability(
+            100):
+        # Xformers FA is not compatible with B200
+        USE_XFORMERS_OPS = False
+    else:
+        try:
+            from importlib.util import find_spec
+
+            find_spec("xformers.ops")
+            USE_XFORMERS_OPS = True
+        except ImportError:
+            USE_XFORMERS_OPS = False
+
+    # the warning only needs to be shown once
+    if not USE_XFORMERS_OPS:
+        logger.warning("Xformers is not available, falling back.")
+
+    return USE_XFORMERS_OPS
+
 
 class Attention(nn.Module):
     """Attention layer.
@@ -314,6 +343,10 @@ class MultiHeadAttention(nn.Module):
             _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
         } else _Backend.TORCH_SDPA
 
+        if (self.attn_backend == _Backend.XFORMERS
+                and not check_xformers_availability()):
+            self.attn_backend = _Backend.TORCH_SDPA
+
     def forward(
         self,
         query: torch.Tensor,
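
For context, the helper added above follows a lazy module-level caching pattern: probe once, remember the result in a global, warn only on the first miss, and let the caller (here `MultiHeadAttention.__init__`) fall back to `_Backend.TORCH_SDPA` when xformers was requested but is unusable. The standalone sketch below illustrates only the probe-and-cache part of that pattern; the names (`_HAS_XFORMERS`, `has_xformers_ops`) are illustrative and not part of the patch, and it omits the device-capability gate the patch applies for compute capability 10.0 (B200).

```python
# Minimal standalone sketch of the lazy availability-check pattern used in the
# patch above. _HAS_XFORMERS and has_xformers_ops() are hypothetical names and
# do not exist in vLLM.
import logging
from importlib.util import find_spec
from typing import Optional

logger = logging.getLogger(__name__)

_HAS_XFORMERS: Optional[bool] = None  # None means "not probed yet"


def has_xformers_ops() -> bool:
    """Probe for xformers.ops once and cache the result."""
    global _HAS_XFORMERS
    if _HAS_XFORMERS is not None:
        return _HAS_XFORMERS

    # find_spec returns None when the submodule is absent and raises
    # ModuleNotFoundError when the parent package is missing, so handle both.
    try:
        _HAS_XFORMERS = find_spec("xformers.ops") is not None
    except ModuleNotFoundError:
        _HAS_XFORMERS = False

    if not _HAS_XFORMERS:
        # Emitted only on the first (cache-miss) call.
        logger.warning("xformers is not available, falling back to SDPA.")
    return _HAS_XFORMERS
```

A caller would then mirror the `MultiHeadAttention.__init__` change: keep the xformers backend only when the probe succeeds, otherwise switch to the SDPA path.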