[FEAT][ROCm] Enable running Flash Attention as ViT attn backend for Qwen-VL models on ROCm platform. (#22069)

Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
Authored by vllmellm on 2025-08-02 14:53:18 +08:00; committed by GitHub
parent 0edaf752d7
commit d3a6f2120b
6 changed files with 64 additions and 39 deletions


@@ -246,11 +246,15 @@ class Qwen2_5_VisionAttention(nn.Module):
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
             )
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }

     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]

@@ -297,9 +301,12 @@ class Qwen2_5_VisionAttention(nn.Module):
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.attn_backend == _Backend.ROCM_AITER_FA:
+                from aiter import flash_attn_varlen_func
+            else:
+                from flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -311,7 +318,7 @@ class Qwen2_5_VisionAttention(nn.Module):
                 cu_seqlens_k=cu_seqlens,
                 max_seqlen_q=max_seqlen,
                 max_seqlen_k=max_seqlen,
-                dropout_p=0,
+                dropout_p=0.0,
                 causal=False)

             context_layer = rearrange(output,

@@ -635,7 +642,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
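
The hunks above route both the existing flash-attention path and the new ROCm AITER path through a single is_flash_attn_backend flag, so only the import site differs. Below is a minimal, hedged sketch of that varlen call pattern; the helper name, tensor shapes, and the use_rocm_aiter flag are illustrative assumptions, while the keyword arguments mirror the ones shown in the diff.

# Hedged sketch of the varlen flash-attention call used by the ViT attention
# above. The helper name and use_rocm_aiter flag are assumptions for
# illustration; the kwargs (dropout_p=0.0, causal=False, shared cu_seqlens)
# follow the diff.
import torch
from einops import rearrange


def vit_varlen_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    cu_seqlens: torch.Tensor, max_seqlen: int,
                    use_rocm_aiter: bool) -> torch.Tensor:
    # Both kernels expose a flash_attn_varlen_func with a compatible
    # signature, so the backend choice reduces to which module is imported.
    if use_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        from flash_attn import flash_attn_varlen_func

    # Pack [batch, seq, heads, head_dim] into [(batch * seq), heads, head_dim],
    # the layout the varlen kernels expect.
    q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in (q, k, v))

    return flash_attn_varlen_func(q, k, v,
                                  cu_seqlens_q=cu_seqlens,
                                  cu_seqlens_k=cu_seqlens,
                                  max_seqlen_q=max_seqlen,
                                  max_seqlen_k=max_seqlen,
                                  dropout_p=0.0,
                                  causal=False)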


@@ -274,10 +274,14 @@ class Qwen2VisionAttention(nn.Module):
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }

     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]

@@ -324,9 +328,12 @@ class Qwen2VisionAttention(nn.Module):
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.attn_backend == _Backend.ROCM_AITER_FA:
+                from aiter import flash_attn_varlen_func
+            else:
+                from flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -338,7 +345,7 @@ class Qwen2VisionAttention(nn.Module):
                 cu_seqlens_k=cu_seqlens,
                 max_seqlen_q=max_seqlen,
                 max_seqlen_k=max_seqlen,
-                dropout_p=0,
+                dropout_p=0.0,
                 causal=False)

             context_layer = rearrange(output,

@@ -620,7 +627,8 @@ class Qwen2VisionTransformer(nn.Module):
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()


@@ -7,9 +7,7 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
 import torch
 from transformers import PretrainedConfig

-import vllm.envs as envs
-from vllm.attention.selector import (backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import get_env_variable_attn_backend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform

@@ -75,33 +73,13 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     Get the available attention backend for Vision Transformer.
     """
     # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
-    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-    if selected_backend is None:
-        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            selected_backend = backend_name_to_enum(backend_by_env_var)
-    if selected_backend is None:
-        if current_platform.is_cuda():
-            device_available = current_platform.has_device_capability(80)
-            if device_available and support_fa:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    selected_backend = _Backend.FLASH_ATTN
-                else:
-                    logger.warning_once(
-                        "Current `vllm-flash-attn` has a bug inside vision "
-                        "module, so we use xformers backend instead. You can "
-                        "run `pip install flash-attn` to use flash-attention "
-                        "backend.")
-                    selected_backend = _Backend.XFORMERS
-            else:
-                # For Volta and Turing GPUs, use xformers instead.
-                selected_backend = _Backend.XFORMERS
-        else:
-            # Default to torch SDPA for other non-GPU platforms.
-            selected_backend = _Backend.TORCH_SDPA
-    return selected_backend
+    selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
+    if selected_backend is not None:
+        return selected_backend
+
+    return current_platform.get_vit_attn_backend(support_fa)


 def resolve_visual_encoder_outputs(
     encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
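
Reassembled from the added lines above, the refactored selector now does only two things: honor an explicit backend override from the environment, then delegate the default choice to the active platform. The sketch below uses only names visible in this diff; the comment about VLLM_ATTENTION_BACKEND reflects the variable the removed code read and is assumed to still apply.

from typing import Optional

from vllm.attention.selector import get_env_variable_attn_backend
from vllm.platforms import _Backend, current_platform


def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
    # 1. An explicit override (e.g. via VLLM_ATTENTION_BACKEND) wins.
    selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
    if selected_backend is not None:
        return selected_backend
    # 2. Otherwise the active platform (CUDA, ROCm, or the base default)
    #    picks its own ViT attention backend via the new hook added below.
    return current_platform.get_vit_attn_backend(support_fa)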


@@ -206,6 +206,20 @@ class CudaPlatformBase(Platform):
             torch.cuda.reset_peak_memory_stats(device)
         return torch.cuda.max_memory_allocated(device)

+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        if cls.has_device_capability(80) and support_fa:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                return _Backend.FLASH_ATTN
+            logger.warning_once(
+                "Current `vllm-flash-attn` has a bug inside vision "
+                "module, so we use xformers backend instead. You can "
+                "run `pip install flash-attn` to use flash-attention "
+                "backend.")
+        # Fallback for Volta/Turing GPUs or FA not supported
+        return _Backend.XFORMERS
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,


@@ -46,6 +46,7 @@ class _Backend(enum.Enum):
     ROCM_FLASH = enum.auto()
     ROCM_AITER_MLA = enum.auto()  # Supported by V1
     ROCM_AITER_MLA_VLLM_V1 = enum.auto()
+    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()

@@ -186,6 +187,10 @@ class Platform:
         else:
             return device_id

+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],


@@ -173,6 +173,18 @@ class RocmPlatform(Platform):
         "quark", "ptpc_fp8"
     ]

+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        if support_fa:
+            if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
+                    and on_gfx9()):
+                # Note: AITER FA is only supported for Qwen-VL models.
+                # TODO: Add support for other VL models in their model class.
+                return _Backend.ROCM_AITER_FA
+            if on_gfx9():
+                return _Backend.FLASH_ATTN
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
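
Taken together with the selector change above, this is roughly how the new path is expected to be exercised on a ROCm machine. The environment-variable names and the gfx9 requirement come straight from the hook above; everything else in this sketch is an illustrative assumption, not part of the commit.

# Hedged usage sketch: opt in to the AITER ViT flash-attention path on ROCm.
# Assumes a ROCm build of vLLM that includes this commit and a gfx9 GPU.
import os

# Both switches are read by RocmPlatform.get_vit_attn_backend() above;
# they are set explicitly here rather than relying on defaults.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MHA"] = "1"

from vllm.platforms import current_platform

# Expected to resolve to _Backend.ROCM_AITER_FA on gfx9; otherwise it falls
# back to FLASH_ATTN (gfx9 without AITER) or TORCH_SDPA.
print(current_platform.get_vit_attn_backend(support_fa=True))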