[BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) and contiguous on qwen3vl ROCm TORCH_SDPA (#27190)

Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
JartX 2025-10-26 08:08:52 +01:00 committed by GitHub
parent d63cd9ff10
commit 65d2cf9511
4 changed files with 46 additions and 12 deletions
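
Net effect of the change: on ROCm, the ViT attention path only selects AITER FA or upstream FlashAttention on GFX9 GPUs (and, for upstream FA, only when the flash_attn package is actually present); every other ROCm GPU now falls back to TORCH_SDPA instead of failing, and the SDPA path receives contiguous q/k/v tensors. Below is a minimal, dependency-free sketch of that selection behavior; the Backend enum and pick_vit_backend helper are illustrative stand-ins, not vLLM's real types, and the real function also honors attn_backend_override and returns the attention callable alongside the backend.

from enum import Enum, auto

class Backend(Enum):
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()

def pick_vit_backend(
    is_rocm: bool,
    is_gfx9: bool,
    aiter_mha_enabled: bool,
    upstream_fa_available: bool,
) -> Backend:
    # Sketch of the ROCm branch added to maybe_get_vit_flash_attn_backend():
    # AITER FA and upstream FlashAttention are only considered on GFX9;
    # everything else on ROCm drops to torch SDPA instead of erroring out.
    if is_rocm:
        if aiter_mha_enabled and is_gfx9:
            return Backend.ROCM_AITER_FA
        if is_gfx9 and upstream_fa_available:
            return Backend.FLASH_ATTN
        return Backend.TORCH_SDPA
    if upstream_fa_available:
        return Backend.FLASH_ATTN
    return Backend.TORCH_SDPA

# Non-GFX9 ROCm GPU (e.g. RDNA): SDPA fallback even if flash_attn is installed.
assert pick_vit_backend(True, False, False, True) is Backend.TORCH_SDPA
# GFX9 with AITER MHA enabled: the ROCm AITER FlashAttention path wins.
assert pick_vit_backend(True, True, True, True) is Backend.ROCM_AITER_FA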


@@ -47,6 +47,12 @@ from vllm.v1.kv_cache_interface import (
     SlidingWindowSpec,
 )
+if current_platform.is_rocm():
+    from vllm.platforms.rocm import on_gfx9
+else:
+    on_gfx9 = lambda *args, **kwargs: False
 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
@@ -96,18 +102,29 @@ def maybe_get_vit_flash_attn_backend(
     attn_backend: _Backend,
     use_upstream_fa: bool,
     attn_backend_override: _Backend | None = None,
-) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-        and attn_backend_override is None
-    ):
-        attn_backend = _Backend.FLASH_ATTN
-        use_upstream_fa = True
+) -> tuple[_Backend, Callable | None]:
+    if current_platform.is_rocm():
+        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+            attn_backend = _Backend.ROCM_AITER_FA
-    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
-        use_upstream_fa = True
+        elif (
+            check_upstream_fa_availability(torch.get_default_dtype())
+            and on_gfx9()
+            and attn_backend_override is None
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+        else:
+            return _Backend.TORCH_SDPA, None
+    elif current_platform.is_cuda():
+        if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+            torch.get_default_dtype()
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+    else:
+        return _Backend.TORCH_SDPA, None
     if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
         if attn_backend == _Backend.ROCM_AITER_FA:
@@ -570,6 +587,7 @@ class MultiHeadAttention(nn.Module):
         value = torch.repeat_interleave(value, num_repeat, dim=2)
         if self.is_flash_attn_backend:
+            assert self._flash_attn_varlen_func is not None
             cu_seqlens_q = torch.arange(
                 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
             )
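
Since the helper now returns tuple[_Backend, Callable | None] (None on the TORCH_SDPA fallback), flash-attention call sites have to narrow the optional callable before invoking it, which is what the new assert in MultiHeadAttention does. A generic sketch of that pattern, with hypothetical names:

from typing import Callable, Optional

def pick(use_flash: bool) -> tuple[str, Optional[Callable[[], int]]]:
    # Hypothetical stand-in for the (backend, attention callable) pair.
    if use_flash:
        return "FLASH_ATTN", lambda: 0
    return "TORCH_SDPA", None

backend, varlen_func = pick(use_flash=True)
if backend == "FLASH_ATTN":
    assert varlen_func is not None  # narrow the Optional before calling
    varlen_func()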


@@ -429,6 +429,12 @@ class Qwen2_5_VisionAttention(nn.Module):
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
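
For context on the new q/k/v .contiguous() calls: the tensors reaching this TORCH_SDPA branch are non-contiguous views (products of the earlier rearrange/transpose), which the ROCm SDPA path evidently mishandles, so they are materialized before attention. A small standalone illustration of the pattern (shapes are made up, not Qwen2.5-VL's):

import torch
import torch.nn.functional as F

# A transpose produces a non-contiguous view, mimicking the layout q/k/v
# arrive in; the (batch, heads, seq, head_dim) shape is purely illustrative.
q = torch.randn(1, 64, 8, 32).transpose(1, 2)
k = torch.randn(1, 64, 8, 32).transpose(1, 2)
v = torch.randn(1, 64, 8, 32).transpose(1, 2)
assert not q.is_contiguous()

# The ROCm-only fix: materialize contiguous copies before calling SDPA.
q, k, v = (t.contiguous() for t in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 32])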


@@ -462,6 +462,12 @@ class Qwen2VisionAttention(nn.Module):
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]


@@ -205,12 +205,16 @@ class RocmPlatform(Platform):
     @classmethod
     def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
+        from importlib.util import find_spec
         from vllm.attention.backends.registry import _Backend
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             return _Backend.ROCM_AITER_FA
-        if on_gfx9():
+        if on_gfx9() and find_spec("flash_attn") is not None:
             return _Backend.FLASH_ATTN
         return _Backend.TORCH_SDPA
     @classmethod
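
The added find_spec("flash_attn") guard means the GFX9 branch only reports FLASH_ATTN when the flash-attn package is actually installed; otherwise the platform falls through to TORCH_SDPA. A quick standalone look at what importlib.util.find_spec does:

from importlib.util import find_spec

# find_spec() returns a ModuleSpec if the package is importable and None
# otherwise, without importing it, so a missing flash-attn wheel is detected
# cheaply at backend-selection time instead of failing later.
print(find_spec("math") is not None)        # True: stdlib module is present
print(find_spec("flash_attn") is not None)  # True only if flash-attn is installed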