diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 7544daa3aff7a..22eaa22b8b385 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -47,6 +47,12 @@ from vllm.v1.kv_cache_interface import (
     SlidingWindowSpec,
 )
 
+if current_platform.is_rocm():
+    from vllm.platforms.rocm import on_gfx9
+else:
+    on_gfx9 = lambda *args, **kwargs: False
+
+
 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
@@ -96,18 +102,29 @@ def maybe_get_vit_flash_attn_backend(
     attn_backend: _Backend,
     use_upstream_fa: bool,
     attn_backend_override: _Backend | None = None,
-) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-        and attn_backend_override is None
-    ):
-        attn_backend = _Backend.FLASH_ATTN
-        use_upstream_fa = True
+) -> tuple[_Backend, Callable | None]:
+    if current_platform.is_rocm():
+        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+            attn_backend = _Backend.ROCM_AITER_FA
 
-    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
-        use_upstream_fa = True
+        elif (
+            check_upstream_fa_availability(torch.get_default_dtype())
+            and on_gfx9()
+            and attn_backend_override is None
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+        else:
+            return _Backend.TORCH_SDPA, None
+
+    elif current_platform.is_cuda():
+        if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+            torch.get_default_dtype()
+        ):
+            attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+        else:
+            return _Backend.TORCH_SDPA, None
 
     if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
         if attn_backend == _Backend.ROCM_AITER_FA:
@@ -570,6 +587,7 @@ class MultiHeadAttention(nn.Module):
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
         if self.is_flash_attn_backend:
+            assert self._flash_attn_varlen_func is not None
             cu_seqlens_q = torch.arange(
                 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
             )
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index c657b06d43553..a3436201a1db6 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -429,6 +429,12 @@ class Qwen2_5_VisionAttention(nn.Module):
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 61f7970d56f60..47ce3ee744edd 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -462,6 +462,12 @@ class Qwen2VisionAttention(nn.Module):
             ).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
+            from vllm.platforms import current_platform
+
+            if current_platform.is_rocm():
+                q = q.contiguous()
+                k = k.contiguous()
+                v = v.contiguous()
             outputs = []
             for i in range(1, len(cu_seqlens)):
                 start_idx = cu_seqlens[i - 1]
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index b2ec40849446d..059ed4430e367 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -205,12 +205,16 @@ class RocmPlatform(Platform):
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
+        from importlib.util import find_spec
+
         from vllm.attention.backends.registry import _Backend
 
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             return _Backend.ROCM_AITER_FA
-        if on_gfx9():
+
+        if on_gfx9() and find_spec("flash_attn") is not None:
             return _Backend.FLASH_ATTN
+
         return _Backend.TORCH_SDPA
 
     @classmethod
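Below is a minimal standalone sketch of the ViT attention-backend selection order this patch establishes on ROCm: AITER FA on gfx9 when both AITER env flags are set, upstream flash-attn on gfx9 only if the flash_attn package is importable, otherwise torch SDPA. The names Backend, pick_vit_attn_backend, is_gfx9, and aiter_enabled are hypothetical stand-ins for vllm's _Backend, RocmPlatform.get_vit_attn_backend, on_gfx9(), and the VLLM_ROCM_USE_AITER* env checks; this is an illustration of the ordering, not the real vllm API.

# Sketch only: mirrors the selection order of the patched
# RocmPlatform.get_vit_attn_backend with hypothetical stand-ins
# (Backend, pick_vit_attn_backend, is_gfx9, aiter_enabled).
from enum import Enum, auto
from importlib.util import find_spec


class Backend(Enum):
    ROCM_AITER_FA = auto()
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def pick_vit_attn_backend(is_gfx9: bool, aiter_enabled: bool) -> Backend:
    # 1) AITER multi-head attention is preferred on gfx9 when AITER is enabled.
    if aiter_enabled and is_gfx9:
        return Backend.ROCM_AITER_FA
    # 2) Otherwise use upstream flash-attn, but only on gfx9 and only if the
    #    flash_attn package is actually installed (the find_spec check).
    if is_gfx9 and find_spec("flash_attn") is not None:
        return Backend.FLASH_ATTN
    # 3) Everything else (non-gfx9, or no flash-attn wheel) falls back to SDPA.
    return Backend.TORCH_SDPA


if __name__ == "__main__":
    # Example: gfx9 without AITER picks FLASH_ATTN only if flash_attn is installed.
    print(pick_vit_attn_backend(is_gfx9=True, aiter_enabled=False))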