diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 04e64422d2e0..45fb7f9580ae 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -246,11 +246,15 @@ class Qwen2_5_VisionAttention(nn.Module):
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
             )
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
 
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
@@ -297,10 +301,13 @@ class Qwen2_5_VisionAttention(nn.Module):
             q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
             k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.attn_backend == _Backend.ROCM_AITER_FA:
+                from aiter import flash_attn_varlen_func
+            else:
+                from flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...")
                        for x in [q, k, v])
@@ -311,7 +318,7 @@ class Qwen2_5_VisionAttention(nn.Module):
                                             cu_seqlens_k=cu_seqlens,
                                             max_seqlen_q=max_seqlen,
                                             max_seqlen_k=max_seqlen,
-                                            dropout_p=0,
+                                            dropout_p=0.0,
                                             causal=False)
 
             context_layer = rearrange(output,
@@ -635,7 +642,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 4e8ea8e44913..40d77312b72c 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -274,10 +274,14 @@ class Qwen2VisionAttention(nn.Module):
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
 
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
@@ -324,10 +328,13 @@ class Qwen2VisionAttention(nn.Module):
             q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
             k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.attn_backend == _Backend.ROCM_AITER_FA:
+                from aiter import flash_attn_varlen_func
+            else:
+                from flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...")
                        for x in [q, k, v])
@@ -338,7 +345,7 @@ class Qwen2VisionAttention(nn.Module):
                                             cu_seqlens_k=cu_seqlens,
                                             max_seqlen_q=max_seqlen,
                                             max_seqlen_k=max_seqlen,
-                                            dropout_p=0,
+                                            dropout_p=0.0,
                                             causal=False)
 
             context_layer = rearrange(output,
@@ -620,7 +627,8 @@ class Qwen2VisionTransformer(nn.Module):
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index ac6a659bbaa3..de30509b1ccb 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -7,9 +7,7 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
 import torch
 from transformers import PretrainedConfig
 
-import vllm.envs as envs
-from vllm.attention.selector import (backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import get_env_variable_attn_backend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 
@@ -75,32 +73,12 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
     """
     # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
-    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-    if selected_backend is None:
-        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            selected_backend = backend_name_to_enum(backend_by_env_var)
-    if selected_backend is None:
-        if current_platform.is_cuda():
-            device_available = current_platform.has_device_capability(80)
-            if device_available and support_fa:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    selected_backend = _Backend.FLASH_ATTN
-                else:
-                    logger.warning_once(
-                        "Current `vllm-flash-attn` has a bug inside vision "
-                        "module, so we use xformers backend instead. You can "
-                        "run `pip install flash-attn` to use flash-attention "
-                        "backend.")
-                    selected_backend = _Backend.XFORMERS
-            else:
-                # For Volta and Turing GPUs, use xformers instead.
-                selected_backend = _Backend.XFORMERS
-        else:
-            # Default to torch SDPA for other non-GPU platforms.
-            selected_backend = _Backend.TORCH_SDPA
-    return selected_backend
+
+    selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
+    if selected_backend is not None:
+        return selected_backend
+
+    return current_platform.get_vit_attn_backend(support_fa)
 
 def resolve_visual_encoder_outputs(
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 87ff6b385809..a90910639f78 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -206,6 +206,20 @@ class CudaPlatformBase(Platform):
         torch.cuda.reset_peak_memory_stats(device)
         return torch.cuda.max_memory_allocated(device)
 
+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        if cls.has_device_capability(80) and support_fa:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                return _Backend.FLASH_ATTN
+            logger.warning_once(
+                "Current `vllm-flash-attn` has a bug inside vision "
+                "module, so we use xformers backend instead. You can "
+                "run `pip install flash-attn` to use flash-attention "
+                "backend.")
+        # Fallback for Volta/Turing GPUs or FA not supported
+        return _Backend.XFORMERS
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 6bae0fe25c79..997aee7063f5 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -46,6 +46,7 @@ class _Backend(enum.Enum):
     ROCM_FLASH = enum.auto()
     ROCM_AITER_MLA = enum.auto()  # Supported by V1
     ROCM_AITER_MLA_VLLM_V1 = enum.auto()
+    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
@@ -186,6 +187,10 @@ class Platform:
         else:
             return device_id
 
+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index b2e69f60343f..54ffc83cd565 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -173,6 +173,18 @@ class RocmPlatform(Platform):
         "quark", "ptpc_fp8"
     ]
 
+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        if support_fa:
+            if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
+                    and on_gfx9()):
+                # Note: AITER FA is only supported for Qwen-VL models.
+                # TODO: Add support for other VL models in their model class.
+                return _Backend.ROCM_AITER_FA
+            if on_gfx9():
+                return _Backend.FLASH_ATTN
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
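
For context (not part of the patch): a minimal standalone sketch of the selection order that the new RocmPlatform.get_vit_attn_backend hook encodes. The use_aiter, use_aiter_mha, and is_gfx9 booleans are hypothetical stand-ins for envs.VLLM_ROCM_USE_AITER, envs.VLLM_ROCM_USE_AITER_MHA, and on_gfx9(); the sketch only illustrates the priority (AITER FA, then FlashAttention on gfx9, then SDPA) and is not vLLM code.

# Standalone sketch, assuming plain booleans in place of the real env flags
# and the gfx9 check; mirrors the priority order of the new ROCm hook.
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()


def rocm_vit_attn_backend(support_fa: bool, use_aiter: bool,
                          use_aiter_mha: bool, is_gfx9: bool) -> Backend:
    # AITER FA only when both AITER flags are enabled on gfx9; plain
    # FlashAttention otherwise on gfx9; SDPA everywhere else.
    if support_fa:
        if use_aiter and use_aiter_mha and is_gfx9:
            return Backend.ROCM_AITER_FA
        if is_gfx9:
            return Backend.FLASH_ATTN
    return Backend.TORCH_SDPA


print(rocm_vit_attn_backend(True, True, True, True))    # Backend.ROCM_AITER_FA
print(rocm_vit_attn_backend(True, False, False, True))  # Backend.FLASH_ATTN
print(rocm_vit_attn_backend(False, True, True, True))   # Backend.TORCH_SDPA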