From f35f896e3a8a1d50a0a82c8908e11e0e751a7b8e Mon Sep 17 00:00:00 2001 From: TJian Date: Thu, 2 Oct 2025 22:34:53 -0700 Subject: [PATCH] [ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104) Signed-off-by: tjtanaa Signed-off-by: yewentao256 --- vllm/attention/layer.py | 72 +++++++++++++++------- vllm/model_executor/models/dots_ocr.py | 41 ++++++------ vllm/model_executor/models/ernie45_vl.py | 49 +++++++-------- vllm/model_executor/models/glm4_1v.py | 31 +++++----- vllm/model_executor/models/qwen2_5_vl.py | 34 +++++----- vllm/model_executor/models/qwen2_vl.py | 40 ++++++------ vllm/model_executor/models/qwen3_vl.py | 4 +- vllm/model_executor/models/siglip2navit.py | 22 +++---- vllm/platforms/rocm.py | 2 - 9 files changed, 154 insertions(+), 141 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 113602645e898..ac34f279d0b57 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import List, Optional +from typing import Callable, List, Optional import torch import torch.nn as nn @@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype): ) and current_platform.has_device_capability(80): from transformers.utils import is_flash_attn_2_available return is_flash_attn_2_available() + if current_platform.is_rocm(): + from importlib.util import find_spec + return find_spec("flash_attn") is not None return False +def maybe_get_vit_flash_attn_backend( + attn_backend: _Backend, + use_upstream_fa: bool) -> tuple[_Backend, Callable]: + if attn_backend != _Backend.FLASH_ATTN and \ + attn_backend != _Backend.ROCM_AITER_FA and \ + check_upstream_fa_availability(torch.get_default_dtype()): + attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if current_platform.is_rocm() and \ + attn_backend == _Backend.FLASH_ATTN: + use_upstream_fa = True + + if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}): + if attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + else: + flash_attn_varlen_func = None + + return attn_backend, flash_attn_varlen_func + + class Attention(nn.Module, AttentionLayerBase): """Attention layer. @@ -410,13 +440,9 @@ class MultiHeadAttention(nn.Module): # to upstream flash attention if available. # If vllm native fa is selected, we use it directly. 
use_upstream_fa = False - if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - dtype): - backend = _Backend.FLASH_ATTN - use_upstream_fa = True - if current_platform.is_rocm() or current_platform.is_xpu(): - # currently, only torch_sdpa is supported on rocm/xpu + if current_platform.is_xpu(): + # currently, only torch_sdpa is supported on xpu self.attn_backend = _Backend.TORCH_SDPA else: @@ -428,17 +454,25 @@ class MultiHeadAttention(nn.Module): _Backend.FLASH_ATTN, } else _Backend.TORCH_SDPA + self.attn_backend, self._flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + use_upstream_fa, + ) + if (self.attn_backend == _Backend.XFORMERS and not check_xformers_availability()): self.attn_backend = _Backend.TORCH_SDPA - if self.attn_backend == _Backend.FLASH_ATTN: - if use_upstream_fa: - from flash_attn import flash_attn_varlen_func - self._flash_attn_varlen_func = flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func - self._flash_attn_varlen_func = flash_attn_varlen_func + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + # this condition is just to make sure that the + # use_upstream_fa in the log is correct + if current_platform.is_rocm() \ + and self.attn_backend == _Backend.FLASH_ATTN: + use_upstream_fa = True logger.info_once( f"MultiHeadAttention attn_backend: {self.attn_backend}, " @@ -466,7 +500,7 @@ class MultiHeadAttention(nn.Module): key = torch.repeat_interleave(key, num_repeat, dim=2) value = torch.repeat_interleave(value, num_repeat, dim=2) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.is_flash_attn_backend: cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, @@ -507,14 +541,6 @@ class MultiHeadAttention(nn.Module): from torch_xla.experimental.custom_kernel import flash_attention out = flash_attention(query, key, value, sm_scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - - # ROCm Flash Attention expects (batch, seq, heads, head_dim) - out = flash_attn_varlen_func(query, - key, - value, - softmax_scale=self.scale) else: # ViT attention hasn't supported this backend yet raise NotImplementedError( diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 2445f0d784f44..86888c10ee398 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -10,7 +10,8 @@ from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import utils as dist_utils from vllm.distributed.parallel_state import ( @@ -267,10 +268,12 @@ class DotsVisionAttention(nn.Module): self.attn_backend = get_vit_attn_backend( self.hidden_size_per_attention_head, torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { 
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -306,25 +309,18 @@ class DotsVisionAttention(nn.Module): q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) - output = flash_attn_varlen_func(q_, - k_, - v_, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q_, + k_, + v_, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = output.view(bs, -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) @@ -611,7 +607,8 @@ class DotsVisionTransformer(nn.Module): self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 0b8e24407602d..8da7b9f2c3e09 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -35,7 +35,8 @@ from einops import rearrange, repeat from transformers import BatchFeature from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -176,14 +177,18 @@ class Ernie4_5_VisionAttention(nn.Module): dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." @@ -239,27 +244,18 @@ class Ernie4_5_VisionAttention(nn.Module): q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = rearrange(output, "(b s) h d -> s b (h d)", @@ -516,7 +512,8 @@ class Ernie4_5_VisionTransformer(nn.Module): self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 315a057e6a7d6..e6e294a143493 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -47,7 +47,8 @@ from transformers.models.glm4v.video_processing_glm4v import ( from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import (get_tensor_model_parallel_world_size, parallel_state) @@ -263,19 +264,26 @@ class Glm4vVisionAttention(nn.Module): head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( f"GLM-4V does not support {self.attn_backend} backend now.") + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape @@ -316,17 +324,11 @@ class Glm4vVisionAttention(nn.Module): qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.attn_backend == _Backend.FLASH_ATTN: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func + if self.is_flash_attn_backend: q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func( + output = self.flash_attn_varlen_func( q, k, v, @@ -774,7 +776,8 @@ class Glm4vVisionTransformer(nn.Module): ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a70df3b72be48..3c46516c7905f 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -39,7 +39,8 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -302,6 +303,11 @@ class Qwen2_5_VisionAttention(nn.Module): disable_tp=use_data_parallel) self.attn_backend = attn_backend self.use_upstream_fa = use_upstream_fa + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) self.is_flash_attn_backend = self.attn_backend in { _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA } @@ -354,25 +360,18 @@ class Qwen2_5_VisionAttention(nn.Module): q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = rearrange(output, "(b s) h d -> s b (h d)", @@ -618,6 +617,7 @@ class Qwen2_5_VisionTransformer(nn.Module): self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) if self.attn_backend != _Backend.FLASH_ATTN and \ + self.attn_backend != _Backend.ROCM_AITER_FA and \ check_upstream_fa_availability( torch.get_default_dtype()): self.attn_backend = _Backend.FLASH_ATTN diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2ff79765d4be2..48dec351bd902 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -42,7 +42,8 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( Qwen2VLVideoProcessor) from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -319,11 +320,12 @@ class Qwen2VisionAttention(nn.Module): head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, @@ -331,6 +333,7 @@ class Qwen2VisionAttention(nn.Module): }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now.") + self.is_flash_attn_backend = self.attn_backend in { _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA } @@ -383,25 +386,18 @@ class Qwen2VisionAttention(nn.Module): q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = rearrange(output, "(b s) h d -> s b (h d)", diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fc8557131c3e8..da6ca7940700f 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module): head_size=head_dim, dtype=torch.get_default_dtype()) use_upstream_fa = False if self.attn_backend != _Backend.FLASH_ATTN and \ + self.attn_backend != _Backend.ROCM_AITER_FA and \ check_upstream_fa_availability( torch.get_default_dtype()): self.attn_backend = _Backend.FLASH_ATTN @@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module): cu_seqlens: torch.Tensor, ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index d111a10809e77..5bea5b1daf4d6 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -14,7 +14,7 @@ from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -240,11 +240,12 @@ class Siglip2Attention(nn.Module): self.attn_backend = get_vit_attn_backend( head_size=self.head_dim, dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, @@ -286,14 +287,7 @@ class Siglip2Attention(nn.Module): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func - attn_output = flash_attn_varlen_func( + attn_output = self.flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(seq_length, -1) elif self.attn_backend == _Backend.TORCH_SDPA: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 
e12967ad25870..de3df03d1fa06 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -189,8 +189,6 @@ class RocmPlatform(Platform): from vllm.attention.backends.registry import _Backend if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. return _Backend.ROCM_AITER_FA if on_gfx9(): return _Backend.FLASH_ATTN
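
The hunks above all converge on one calling pattern: resolve the ViT backend once through maybe_get_vit_flash_attn_backend(), remember whether the result is a flash-attention-style backend (FLASH_ATTN or ROCM_AITER_FA), and call the returned flash_attn_varlen_func instead of importing aiter / flash_attn / vllm.vllm_flash_attn at every attention call site. Below is a minimal sketch of that pattern, assuming a vLLM tree with this patch applied; the wrapper name build_vit_attention is hypothetical and does not appear in the diff, while the dispatcher and backend enum are the ones introduced/used above.

import torch

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


def build_vit_attention(attn_backend: _Backend):
    """Resolve the ViT attention backend the way the patched modules do.

    ``attn_backend`` is whatever the platform selector returned
    (e.g. FLASH_ATTN, ROCM_AITER_FA, TORCH_SDPA, XFORMERS).
    """
    use_upstream_fa = False

    # The dispatcher may upgrade the backend to FLASH_ATTN (upstream
    # flash-attn) when available, keeps ROCM_AITER_FA as-is, and returns
    # the matching flash_attn_varlen_func, or None for non-FA backends.
    attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend,
        use_upstream_fa,
    )

    # Both FLASH_ATTN and ROCM_AITER_FA take the varlen code path,
    # mirroring the is_flash_attn_backend checks added in the models above.
    is_flash_attn_backend = attn_backend in {
        _Backend.FLASH_ATTN,
        _Backend.ROCM_AITER_FA,
    }
    return attn_backend, flash_attn_varlen_func, is_flash_attn_backend

At the call site the returned function is then invoked with the usual varlen arguments (q, k, v flattened to (total_tokens, heads, head_dim), cu_seqlens_q/cu_seqlens_k, max_seqlen_q/max_seqlen_k, dropout_p=0.0, causal=False), exactly as the model hunks in this patch do.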