[BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) and contiguous on qwen3vl ROCm TORCH_SDPA (#27190)
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
commit 65d2cf9511
parent d63cd9ff10
@@ -47,6 +47,12 @@ from vllm.v1.kv_cache_interface import (
     SlidingWindowSpec,
 )
 
+if current_platform.is_rocm():
+    from vllm.platforms.rocm import on_gfx9
+else:
+    on_gfx9 = lambda *args, **kwargs: False
+
+
 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
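With this guard in place, call sites can use on_gfx9() unconditionally: on ROCm it is the real check for GFX9 (CDNA / Instinct MI-series) hardware, everywhere else it is a stub that always returns False. Below is a minimal sketch of that conditional-import-with-stub pattern; the _on_rocm flag is only a stand-in for current_platform.is_rocm(), not vLLM's real API.

    # Sketch of the pattern used in the hunk above.
    _on_rocm = False  # assumption: True only on a ROCm build

    if _on_rocm:
        from vllm.platforms.rocm import on_gfx9  # real GFX9 check
    else:
        # Stub keeps callers working on CUDA/CPU: GFX9-only paths are skipped.
        on_gfx9 = lambda *args, **kwargs: False

    print(on_gfx9())  # False everywhere except ROCm GFX9 devices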
@@ -96,18 +102,29 @@ def maybe_get_vit_flash_attn_backend(
     attn_backend: _Backend,
     use_upstream_fa: bool,
     attn_backend_override: _Backend | None = None,
-) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-        and attn_backend_override is None
-    ):
-        attn_backend = _Backend.FLASH_ATTN
-        use_upstream_fa = True
-
-    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
-        use_upstream_fa = True
+) -> tuple[_Backend, Callable | None]:
+    if current_platform.is_rocm():
+        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+            attn_backend = _Backend.ROCM_AITER_FA
+
+        elif (
+            check_upstream_fa_availability(torch.get_default_dtype())
+            and on_gfx9()
+            and attn_backend_override is None
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+        else:
+            return _Backend.TORCH_SDPA, None
+
+    elif current_platform.is_cuda():
+        if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+            torch.get_default_dtype()
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+    else:
+        return _Backend.TORCH_SDPA, None
 
     if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
         if attn_backend == _Backend.ROCM_AITER_FA:
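Rewritten this way, the helper selects the ViT attention backend per platform: on ROCm it prefers AITER FA on GFX9 when the AITER env flags are set, falls back to upstream FlashAttention on GFX9 when it is usable and no override was given, and otherwise returns TORCH_SDPA with no attention callable; on CUDA it upgrades to FlashAttention when available; on anything else it returns TORCH_SDPA. The sketch below flattens that decision flow into a standalone function; the Backend enum, flag names, and select_vit_backend itself are illustrative stand-ins, not vLLM's real API.

    from enum import Enum, auto

    class Backend(Enum):  # stand-in for vLLM's _Backend enum
        FLASH_ATTN = auto()
        ROCM_AITER_FA = auto()
        TORCH_SDPA = auto()

    def select_vit_backend(
        requested: Backend,
        *,
        is_rocm: bool,
        is_cuda: bool,
        on_gfx9: bool,
        aiter_enabled: bool,          # VLLM_ROCM_USE_AITER + VLLM_ROCM_USE_AITER_MHA
        upstream_fa_available: bool,
        has_override: bool,
    ) -> Backend:
        """Flattened mirror of the per-platform selection in the hunk above."""
        if is_rocm:
            if aiter_enabled and on_gfx9:
                return Backend.ROCM_AITER_FA
            if upstream_fa_available and on_gfx9 and not has_override:
                return Backend.FLASH_ATTN
            return Backend.TORCH_SDPA        # non-GFX9 ROCm: plain SDPA, no attn fn
        if is_cuda:
            if requested is not Backend.FLASH_ATTN and upstream_fa_available:
                return Backend.FLASH_ATTN
            return requested                  # keep whatever the caller asked for
        return Backend.TORCH_SDPA             # any other platform

    # A non-GFX9 ROCm GPU now lands on the SDPA fallback instead of FLASH_ATTN,
    # even when upstream flash-attn happens to be importable.
    assert select_vit_backend(
        Backend.TORCH_SDPA,
        is_rocm=True, is_cuda=False, on_gfx9=False,
        aiter_enabled=False, upstream_fa_available=True, has_override=False,
    ) is Backend.TORCH_SDPA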
@@ -570,6 +587,7 @@ class MultiHeadAttention(nn.Module):
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
         if self.is_flash_attn_backend:
+            assert self._flash_attn_varlen_func is not None
             cu_seqlens_q = torch.arange(
                 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
             )
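Because maybe_get_vit_flash_attn_backend can now return (TORCH_SDPA, None), the flash-attn callable is optional, and the new assert both documents the invariant (the flash path is only taken when a function was resolved) and narrows the type before the call. A small illustration of that pattern with a hypothetical attn_fn, not vLLM code:

    from typing import Callable, Optional

    def run(is_flash: bool, attn_fn: Optional[Callable[[], str]]) -> str:
        if is_flash:
            # The backend resolution only reports the flash path together with a
            # callable, so this assert encodes that invariant and narrows
            # Optional[Callable] -> Callable for type checkers.
            assert attn_fn is not None
            return attn_fn()
        return "sdpa fallback"

    print(run(False, None))            # "sdpa fallback"
    print(run(True, lambda: "flash"))  # "flash"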
@@ -429,6 +429,12 @@ class Qwen2_5_VisionAttention(nn.Module):
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
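The guard matters because q, k and v reach this branch after reshapes and transposes, so they are usually views with non-contiguous strides, which is what this commit works around on the ROCm SDPA path by materializing dense copies. A minimal torch sketch of the layout issue, with illustrative shapes only:

    import torch

    # A transposed view keeps the original storage but permutes the strides.
    q = torch.randn(2, 16, 8, 64).transpose(1, 2)   # (batch, heads, seq, head_dim) view
    print(q.is_contiguous())        # False: strides no longer match the shape

    q = q.contiguous()              # dense copy with row-major strides
    print(q.is_contiguous())        # True: safe for kernels that assume dense layout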
@@ -462,6 +462,12 @@ class Qwen2VisionAttention(nn.Module):
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
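This is the same contiguity guard applied in Qwen2VisionAttention; the loop that follows slices each sequence out of q/k/v and runs torch's scaled_dot_product_attention on it. A self-contained sketch of one such call on already-contiguous tensors; the shapes (1 batch, 8 heads, 16 tokens, head_dim 64) are illustrative:

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 16, 64).contiguous()
    k = torch.randn(1, 8, 16, 64).contiguous()
    v = torch.randn(1, 8, 16, 64).contiguous()

    out = F.scaled_dot_product_attention(q, k, v)   # (1, 8, 16, 64)
    print(out.shape)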
@@ -205,12 +205,16 @@ class RocmPlatform(Platform):
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
+        from importlib.util import find_spec
+
         from vllm.attention.backends.registry import _Backend
 
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             return _Backend.ROCM_AITER_FA
-        if on_gfx9():
+
+        if on_gfx9() and find_spec("flash_attn") is not None:
             return _Backend.FLASH_ATTN
+
         return _Backend.TORCH_SDPA
 
     @classmethod
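get_vit_attn_backend now only advertises FLASH_ATTN on GFX9 when the flash_attn package is actually installed; importlib.util.find_spec locates the package without importing it. A standalone sketch of that probe, with a hypothetical helper name:

    from importlib.util import find_spec

    def flash_attn_installed() -> bool:
        # find_spec returns None when the package cannot be located; unlike a
        # bare import attempt, it does not execute the package's __init__.
        return find_spec("flash_attn") is not None

    print(flash_attn_installed())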