Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
  parent 218349d760
  commit f35f896e3a
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
-from typing import List, Optional
+from typing import Callable, List, Optional

 import torch
 import torch.nn as nn
@@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype):
     ) and current_platform.has_device_capability(80):
         from transformers.utils import is_flash_attn_2_available
         return is_flash_attn_2_available()
+    if current_platform.is_rocm():
+        from importlib.util import find_spec
+        return find_spec("flash_attn") is not None
     return False


+def maybe_get_vit_flash_attn_backend(
+        attn_backend: _Backend,
+        use_upstream_fa: bool) -> tuple[_Backend, Callable]:
+    if attn_backend != _Backend.FLASH_ATTN and \
+        attn_backend != _Backend.ROCM_AITER_FA and \
+        check_upstream_fa_availability(torch.get_default_dtype()):
+        attn_backend = _Backend.FLASH_ATTN
+        use_upstream_fa = True
+
+    if current_platform.is_rocm() and \
+        attn_backend == _Backend.FLASH_ATTN:
+        use_upstream_fa = True
+
+    if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}):
+        if attn_backend == _Backend.ROCM_AITER_FA:
+            from aiter import flash_attn_varlen_func
+        else:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+    else:
+        flash_attn_varlen_func = None
+
+    return attn_backend, flash_attn_varlen_func
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
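For reference, a minimal sketch of how a ViT attention module is expected to consume this dispatcher: resolve the backend and the varlen kernel once at construction time, then call the stored function in forward. The `MyVisionAttention` class below is hypothetical and only illustrates the pattern the model files adopt later in this diff.

# Hypothetical usage sketch (not part of this commit): resolve once, call later.
import torch

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


class MyVisionAttention(torch.nn.Module):

    def __init__(self, attn_backend: _Backend):
        super().__init__()
        self.use_upstream_fa = False
        # One-time dispatch; may upgrade the backend and returns the kernel.
        self.attn_backend, self.flash_attn_varlen_func = \
            maybe_get_vit_flash_attn_backend(attn_backend, self.use_upstream_fa)
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def forward(self, q, k, v, cu_seqlens, max_seqlen):
        if self.is_flash_attn_backend:
            # Stored callable: no per-call imports of flash_attn/aiter needed.
            return self.flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
                dropout_p=0.0, causal=False)
        raise NotImplementedError(f"{self.attn_backend} is not handled here")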
@@ -410,13 +440,9 @@ class MultiHeadAttention(nn.Module):
         # to upstream flash attention if available.
         # If vllm native fa is selected, we use it directly.
         use_upstream_fa = False
-        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-                dtype):
-            backend = _Backend.FLASH_ATTN
-            use_upstream_fa = True

-        if current_platform.is_rocm() or current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on rocm/xpu
+        if current_platform.is_xpu():
+            # currently, only torch_sdpa is supported on xpu
             self.attn_backend = _Backend.TORCH_SDPA
         else:
@@ -428,17 +454,25 @@ class MultiHeadAttention(nn.Module):
                 _Backend.FLASH_ATTN,
             } else _Backend.TORCH_SDPA

+        self.attn_backend, self._flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa,
+            )
+
         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA

-        if self.attn_backend == _Backend.FLASH_ATTN:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-                self._flash_attn_varlen_func = flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
-                self._flash_attn_varlen_func = flash_attn_varlen_func
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
+
+        # this condition is just to make sure that the
+        # use_upstream_fa in the log is correct
+        if current_platform.is_rocm() \
+                and self.attn_backend == _Backend.FLASH_ATTN:
+            use_upstream_fa = True
+
         logger.info_once(
             f"MultiHeadAttention attn_backend: {self.attn_backend}, "
@@ -466,7 +500,7 @@ class MultiHeadAttention(nn.Module):
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)

-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                         step=q_len,
                                         dtype=torch.int32,
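As a side note, the cu_seqlens_q built here are the cumulative token boundaries that the varlen flash-attention kernels expect; for a batch of equal-length sequences they are just multiples of q_len. A tiny worked example with made-up values:

# Illustrative only: cumulative boundaries for bsz=2 sequences of length 3.
import torch

bsz, q_len = 2, 3
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 3, 6], dtype=torch.int32)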
@@ -507,14 +541,6 @@ class MultiHeadAttention(nn.Module):
             from torch_xla.experimental.custom_kernel import flash_attention
             out = flash_attention(query, key, value, sm_scale=self.scale)
             out = out.transpose(1, 2)
-        elif self.attn_backend == _Backend.ROCM_AITER_FA:
-            from aiter import flash_attn_varlen_func
-
-            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
-            out = flash_attn_varlen_func(query,
-                                         key,
-                                         value,
-                                         softmax_scale=self.scale)
         else:
             # ViT attention hasn't supported this backend yet
             raise NotImplementedError(
@@ -10,7 +10,8 @@ from torch.nn import LayerNorm
 from transformers.models.qwen2_vl import Qwen2VLProcessor

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import utils as dist_utils
 from vllm.distributed.parallel_state import (
@@ -267,10 +268,12 @@ class DotsVisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             self.hidden_size_per_attention_head, torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -306,25 +309,18 @@ class DotsVisionAttention(nn.Module):
             q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
             k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
             v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
-            output = flash_attn_varlen_func(q_,
-                                            k_,
-                                            v_,
-                                            cu_seqlens_q=cu_seqlens,
-                                            cu_seqlens_k=cu_seqlens,
-                                            max_seqlen_q=max_seqlen,
-                                            max_seqlen_k=max_seqlen,
-                                            dropout_p=0.0,
-                                            causal=False)
+            output = self.flash_attn_varlen_func(q_,
+                                                 k_,
+                                                 v_,
+                                                 cu_seqlens_q=cu_seqlens,
+                                                 cu_seqlens_k=cu_seqlens,
+                                                 max_seqlen_q=max_seqlen,
+                                                 max_seqlen_k=max_seqlen,
+                                                 dropout_p=0.0,
+                                                 causal=False)
             context_layer = output.view(bs, -1,
                                         self.num_attention_heads_per_partition,
                                         self.hidden_size_per_attention_head)
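The reshapes above flatten (batch, seq, heads, head_dim) into the (total_tokens, heads, head_dim) layout that the varlen kernel consumes. A small shape sketch with made-up dimensions:

# Shape sketch with made-up dimensions (not part of the diff).
import torch

bs, seq, heads, head_dim = 2, 5, 4, 32
q = torch.randn(bs, seq, heads, head_dim)
q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
print(q_.shape)  # torch.Size([10, 4, 32])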
@@ -611,7 +607,8 @@ class DotsVisionTransformer(nn.Module):
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
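For intuition, max_seqlen (needed by the FLASH_ATTN and ROCM_AITER_FA paths) and seqlens (needed by XFORMERS) are both derived from the same cu_seqlens differences. A worked example with made-up boundaries:

# Worked example with made-up boundaries (illustrative only).
import torch

cu_seqlens = torch.tensor([0, 4, 10, 13])
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()          # [4, 6, 3]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()   # 6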
@@ -35,7 +35,8 @@ from einops import rearrange, repeat
 from transformers import BatchFeature

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -176,14 +177,18 @@ class Ernie4_5_VisionAttention(nn.Module):
             dtype=torch.get_default_dtype())

         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )

         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
+                _Backend.FLASH_ATTN,
+                _Backend.TORCH_SDPA,
+                _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"Ernie45-VL does not support {self.attn_backend} backend now."
@@ -239,27 +244,18 @@ class Ernie4_5_VisionAttention(nn.Module):
         q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
-            # from vllm_flash_attn.flash_attn_interface import (
-            #     flash_attn_varlen_func)
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

-            output = flash_attn_varlen_func(q,
-                                            k,
-                                            v,
-                                            cu_seqlens_q=cu_seqlens,
-                                            cu_seqlens_k=cu_seqlens,
-                                            max_seqlen_q=max_seqlen,
-                                            max_seqlen_k=max_seqlen,
-                                            dropout_p=0.0,
-                                            causal=False)
+            output = self.flash_attn_varlen_func(q,
+                                                 k,
+                                                 v,
+                                                 cu_seqlens_q=cu_seqlens,
+                                                 cu_seqlens_k=cu_seqlens,
+                                                 max_seqlen_q=max_seqlen,
+                                                 max_seqlen_k=max_seqlen,
+                                                 dropout_p=0.0,
+                                                 causal=False)

             context_layer = rearrange(output,
                                       "(b s) h d -> s b (h d)",
@@ -516,7 +512,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@@ -47,7 +47,8 @@ from transformers.models.glm4v.video_processing_glm4v import (
 from transformers.video_utils import VideoMetadata

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)
@@ -263,19 +264,26 @@ class Glm4vVisionAttention(nn.Module):
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,
                 _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"GLM-4V does not support {self.attn_backend} backend now.")

+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
+
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
         seq_len, bs, _ = qkv.shape
@@ -316,17 +324,11 @@ class Glm4vVisionAttention(nn.Module):
         qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
         q, k = torch.chunk(qk_rotated, 2, dim=0)

-        if self.attn_backend == _Backend.FLASH_ATTN:
-            # from vllm_flash_attn.flash_attn_interface import (
-            #     flash_attn_varlen_func)
-            if self.use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
+        if self.is_flash_attn_backend:

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

-            output = flash_attn_varlen_func(
+            output = self.flash_attn_varlen_func(
                 q,
                 k,
                 v,
@@ -774,7 +776,8 @@ class Glm4vVisionTransformer(nn.Module):
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens
@@ -39,7 +39,8 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -302,6 +303,11 @@ class Qwen2_5_VisionAttention(nn.Module):
                                       disable_tp=use_data_parallel)
         self.attn_backend = attn_backend
         self.use_upstream_fa = use_upstream_fa
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -354,25 +360,18 @@ class Qwen2_5_VisionAttention(nn.Module):
             q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

-            output = flash_attn_varlen_func(q,
-                                            k,
-                                            v,
-                                            cu_seqlens_q=cu_seqlens,
-                                            cu_seqlens_k=cu_seqlens,
-                                            max_seqlen_q=max_seqlen,
-                                            max_seqlen_k=max_seqlen,
-                                            dropout_p=0.0,
-                                            causal=False)
+            output = self.flash_attn_varlen_func(q,
+                                                 k,
+                                                 v,
+                                                 cu_seqlens_q=cu_seqlens,
+                                                 cu_seqlens_k=cu_seqlens,
+                                                 max_seqlen_q=max_seqlen,
+                                                 max_seqlen_k=max_seqlen,
+                                                 dropout_p=0.0,
+                                                 causal=False)

             context_layer = rearrange(output,
                                       "(b s) h d -> s b (h d)",
@@ -618,6 +617,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
         if self.attn_backend != _Backend.FLASH_ATTN and \
+            self.attn_backend != _Backend.ROCM_AITER_FA and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
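The extra ROCM_AITER_FA check added here mirrors the guard inside maybe_get_vit_flash_attn_backend: without it, an importable upstream flash_attn would silently replace the AITER backend chosen by the ROCm platform. A condensed sketch of the intended behavior (illustrative helper, not vLLM code):

# Condensed, hypothetical sketch of the guard's effect (not vLLM code).
from vllm.attention.backends.registry import _Backend


def resolve(backend: _Backend, upstream_fa_available: bool) -> _Backend:
    # Only non-FA backends get upgraded; ROCM_AITER_FA is left untouched.
    if backend not in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA) \
            and upstream_fa_available:
        return _Backend.FLASH_ATTN
    return backend


assert resolve(_Backend.ROCM_AITER_FA, True) is _Backend.ROCM_AITER_FA
assert resolve(_Backend.TORCH_SDPA, True) is _Backend.FLASH_ATTN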
@@ -42,7 +42,8 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
     Qwen2VLVideoProcessor)

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
@@ -319,11 +320,12 @@ class Qwen2VisionAttention(nn.Module):
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-                check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )

         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
@@ -331,6 +333,7 @@ class Qwen2VisionAttention(nn.Module):
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
+
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
@@ -383,25 +386,18 @@ class Qwen2VisionAttention(nn.Module):
             q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

-            output = flash_attn_varlen_func(q,
-                                            k,
-                                            v,
-                                            cu_seqlens_q=cu_seqlens,
-                                            cu_seqlens_k=cu_seqlens,
-                                            max_seqlen_q=max_seqlen,
-                                            max_seqlen_k=max_seqlen,
-                                            dropout_p=0.0,
-                                            causal=False)
+            output = self.flash_attn_varlen_func(q,
+                                                 k,
+                                                 v,
+                                                 cu_seqlens_q=cu_seqlens,
+                                                 cu_seqlens_k=cu_seqlens,
+                                                 max_seqlen_q=max_seqlen,
+                                                 max_seqlen_k=max_seqlen,
+                                                 dropout_p=0.0,
+                                                 causal=False)

             context_layer = rearrange(output,
                                       "(b s) h d -> s b (h d)",
@@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module):
             head_size=head_dim, dtype=torch.get_default_dtype())
         use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
+            self.attn_backend != _Backend.ROCM_AITER_FA and \
             check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
@@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@@ -14,7 +14,7 @@ from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -240,11 +240,12 @@ class Siglip2Attention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.head_dim, dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-                check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )

         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
@@ -286,14 +287,7 @@ class Siglip2Attention(nn.Module):

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
-            attn_output = flash_attn_varlen_func(
+            attn_output = self.flash_attn_varlen_func(
                 queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                 max_seqlen).reshape(seq_length, -1)
         elif self.attn_backend == _Backend.TORCH_SDPA:
@@ -189,8 +189,6 @@ class RocmPlatform(Platform):
         from vllm.attention.backends.registry import _Backend
         if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                 and on_gfx9()):
-            # Note: AITER FA is only supported for Qwen-VL models.
-            # TODO: Add support for other VL models in their model class.
             return _Backend.ROCM_AITER_FA
         if on_gfx9():
             return _Backend.FLASH_ATTN
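With the Qwen-VL-only caveat removed, the backend the ROCm platform reports is now consumed uniformly through the new dispatcher. A minimal sketch of the assumed end-to-end flow (environment values are examples, not defaults, and the aiter package must be importable):

# Assumed flow on ROCm gfx9 with AITER enabled (illustrative only):
#   VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=1
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend

# The platform reports ROCM_AITER_FA for ViT attention...
backend = _Backend.ROCM_AITER_FA
# ...and the dispatcher keeps it and hands back aiter.flash_attn_varlen_func.
backend, varlen_fn = maybe_get_vit_flash_attn_backend(backend, False)
assert backend is _Backend.ROCM_AITER_FA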