mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 08:45:01 +08:00
[BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) and contiguous on qwen3vl ROCm TORCH_SDPA (#27190)
Signed-off-by: JartX <sagformas@epdcenter.es> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
d63cd9ff10
commit
65d2cf9511
@ -47,6 +47,12 @@ from vllm.v1.kv_cache_interface import (
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
else:
|
||||
on_gfx9 = lambda *args, **kwargs: False
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
logger = init_logger(__name__)
|
||||
USE_XFORMERS_OPS = None
|
||||
@ -96,18 +102,29 @@ def maybe_get_vit_flash_attn_backend(
|
||||
attn_backend: _Backend,
|
||||
use_upstream_fa: bool,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
) -> tuple[_Backend, Callable]:
|
||||
if (
|
||||
attn_backend != _Backend.FLASH_ATTN
|
||||
and attn_backend != _Backend.ROCM_AITER_FA
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
and attn_backend_override is None
|
||||
):
|
||||
attn_backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
) -> tuple[_Backend, Callable | None]:
|
||||
if current_platform.is_rocm():
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
attn_backend = _Backend.ROCM_AITER_FA
|
||||
|
||||
if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
|
||||
use_upstream_fa = True
|
||||
elif (
|
||||
check_upstream_fa_availability(torch.get_default_dtype())
|
||||
and on_gfx9()
|
||||
and attn_backend_override is None
|
||||
):
|
||||
attn_backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
else:
|
||||
return _Backend.TORCH_SDPA, None
|
||||
|
||||
elif current_platform.is_cuda():
|
||||
if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
):
|
||||
attn_backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
else:
|
||||
return _Backend.TORCH_SDPA, None
|
||||
|
||||
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||
if attn_backend == _Backend.ROCM_AITER_FA:
|
||||
@ -570,6 +587,7 @@ class MultiHeadAttention(nn.Module):
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
assert self._flash_attn_varlen_func is not None
|
||||
cu_seqlens_q = torch.arange(
|
||||
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
|
||||
)
|
||||
|
||||
@ -429,6 +429,12 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
|
||||
@ -462,6 +462,12 @@ class Qwen2VisionAttention(nn.Module):
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
|
||||
@ -205,12 +205,16 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
from importlib.util import find_spec
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
return _Backend.ROCM_AITER_FA
|
||||
if on_gfx9():
|
||||
|
||||
if on_gfx9() and find_spec("flash_attn") is not None:
|
||||
return _Backend.FLASH_ATTN
|
||||
|
||||
return _Backend.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user