[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
TJian 2025-10-02 22:34:53 -07:00 committed by yewentao256
parent 218349d760
commit f35f896e3a
9 changed files with 154 additions and 141 deletions

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import List, Optional
from typing import Callable, List, Optional
import torch
import torch.nn as nn
@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype):
) and current_platform.has_device_capability(80):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
if current_platform.is_rocm():
from importlib.util import find_spec
return find_spec("flash_attn") is not None
return False
def maybe_get_vit_flash_attn_backend(
attn_backend: _Backend,
use_upstream_fa: bool) -> tuple[_Backend, Callable]:
if attn_backend != _Backend.FLASH_ATTN and \
attn_backend != _Backend.ROCM_AITER_FA and \
check_upstream_fa_availability(torch.get_default_dtype()):
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if current_platform.is_rocm() and \
attn_backend == _Backend.FLASH_ATTN:
use_upstream_fa = True
if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}):
if attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
return attn_backend, flash_attn_varlen_func
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
@ -410,13 +440,9 @@ class MultiHeadAttention(nn.Module):
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False
if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
dtype):
backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if current_platform.is_rocm() or current_platform.is_xpu():
# currently, only torch_sdpa is supported on rocm/xpu
if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
else:
@ -428,17 +454,25 @@ class MultiHeadAttention(nn.Module):
_Backend.FLASH_ATTN,
} else _Backend.TORCH_SDPA
self.attn_backend, self._flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
)
if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
self.attn_backend = _Backend.TORCH_SDPA
if self.attn_backend == _Backend.FLASH_ATTN:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if current_platform.is_rocm() \
and self.attn_backend == _Backend.FLASH_ATTN:
use_upstream_fa = True
logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
@ -466,7 +500,7 @@ class MultiHeadAttention(nn.Module):
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
if self.attn_backend == _Backend.FLASH_ATTN:
if self.is_flash_attn_backend:
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
@ -507,14 +541,6 @@ class MultiHeadAttention(nn.Module):
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
# ROCm Flash Attention expects (batch, seq, heads, head_dim)
out = flash_attn_varlen_func(query,
key,
value,
softmax_scale=self.scale)
else:
# ViT attention hasn't supported this backend yet
raise NotImplementedError(

View File

@ -10,7 +10,8 @@ from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
@ -267,10 +268,12 @@ class DotsVisionAttention(nn.Module):
self.attn_backend = get_vit_attn_backend(
self.hidden_size_per_attention_head, torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
@ -306,25 +309,18 @@ class DotsVisionAttention(nn.Module):
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
output = flash_attn_varlen_func(q_,
k_,
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
output = self.flash_attn_varlen_func(q_,
k_,
v_,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
context_layer = output.view(bs, -1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
@ -611,7 +607,8 @@ class DotsVisionTransformer(nn.Module):
self, cu_seqlens: torch.Tensor
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@ -35,7 +35,8 @@ from einops import rearrange, repeat
from transformers import BatchFeature
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@ -176,14 +177,18 @@ class Ernie4_5_VisionAttention(nn.Module):
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now."
@ -239,27 +244,18 @@ class Ernie4_5_VisionAttention(nn.Module):
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
output = self.flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
context_layer = rearrange(output,
"(b s) h d -> s b (h d)",
@ -516,7 +512,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
self, cu_seqlens: torch.Tensor
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@ -47,7 +47,8 @@ from transformers.models.glm4v.video_processing_glm4v import (
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state)
@ -263,19 +264,26 @@ class Glm4vVisionAttention(nn.Module):
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now.")
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@ -316,17 +324,11 @@ class Glm4vVisionAttention(nn.Module):
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
if self.is_flash_attn_backend:
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
output = self.flash_attn_varlen_func(
q,
k,
v,
@ -774,7 +776,8 @@ class Glm4vVisionTransformer(nn.Module):
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens

View File

@ -39,7 +39,8 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@ -302,6 +303,11 @@ class Qwen2_5_VisionAttention(nn.Module):
disable_tp=use_data_parallel)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
@ -354,25 +360,18 @@ class Qwen2_5_VisionAttention(nn.Module):
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
output = self.flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
context_layer = rearrange(output,
"(b s) h d -> s b (h d)",
@ -618,6 +617,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
self.attn_backend != _Backend.ROCM_AITER_FA and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN

View File

@ -42,7 +42,8 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
Qwen2VLVideoProcessor)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
@ -319,11 +320,12 @@ class Qwen2VisionAttention(nn.Module):
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
@ -331,6 +333,7 @@ class Qwen2VisionAttention(nn.Module):
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now.")
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
@ -383,25 +386,18 @@ class Qwen2VisionAttention(nn.Module):
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
output = self.flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False)
context_layer = rearrange(output,
"(b s) h d -> s b (h d)",

View File

@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module):
head_size=head_dim, dtype=torch.get_default_dtype())
use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
self.attn_backend != _Backend.ROCM_AITER_FA and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN:
if (self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@ -14,7 +14,7 @@ from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -240,11 +240,12 @@ class Siglip2Attention(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=self.head_dim, dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
self.attn_backend, self.flash_attn_varlen_func \
= maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
@ -286,14 +287,7 @@ class Siglip2Attention(nn.Module):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
attn_output = flash_attn_varlen_func(
attn_output = self.flash_attn_varlen_func(
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
max_seqlen).reshape(seq_length, -1)
elif self.attn_backend == _Backend.TORCH_SDPA:

View File

@ -189,8 +189,6 @@ class RocmPlatform(Platform):
from vllm.attention.backends.registry import _Backend
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
and on_gfx9()):
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA
if on_gfx9():
return _Backend.FLASH_ATTN