From 6cbd41909edaee82786b5a896ccc2fbccd934aa8 Mon Sep 17 00:00:00 2001
From: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Date: Wed, 10 Sep 2025 09:10:14 -0400
Subject: [PATCH] Feature/vit attention unification# 23880 (#23978)

Signed-off-by: Isotr0py
Co-authored-by: Isotr0py
---
 tests/kernels/attention/test_mha_attn.py   | 22 ++++++++++++----
 vllm/attention/layer.py                    | 25 +++++++++++++-----
 .../models/idefics2_vision_model.py        |  3 +++
 vllm/model_executor/models/intern_vit.py   | 11 ++++----
 vllm/model_executor/models/interns1_vit.py | 17 ++++++-------
 vllm/model_executor/models/mllama.py       | 20 ++++++---------
 vllm/model_executor/models/step3_vl.py     | 23 ++++++-----------
 vllm/model_executor/models/vision.py       |  2 +-
 vllm/platforms/interface.py                |  1 +
 9 files changed, 68 insertions(+), 56 deletions(-)

diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index 53c37554b15a3..c01ea32994da0 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -23,6 +23,9 @@ def clear_cache():
     """Clear lru cache to ensure each test case runs without caching.
     """
     _cached_get_attn_backend.cache_clear()
+    # Clear xformers availability cache
+    import vllm.attention.layer as layer_module
+    layer_module.USE_XFORMERS_OPS = None
 
 
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@@ -33,19 +36,28 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CpuPlatform()), \
+             patch("vllm.platforms.current_platform", CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA
+            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   RocmPlatform()), \
+             patch("vllm.platforms.current_platform", RocmPlatform()), \
+             patch("vllm.attention.layer.current_platform", RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.platforms.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
 
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.platforms.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
 
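For orientation, the interface exercised by the test above can also be called directly. A minimal usage sketch follows (assumes a CUDA build with this patch applied; the sizes, dtype, and scale are illustrative, not taken from the patch):

    # Minimal sketch: calling the unified ViT MultiHeadAttention directly.
    import torch
    from vllm.attention.layer import MultiHeadAttention

    num_heads, head_dim = 16, 64
    attn = MultiHeadAttention(num_heads, head_dim, scale=head_dim**-0.5)

    bsz, seq_len = 2, 197  # e.g. a ViT patch sequence
    q = torch.randn(bsz, seq_len, num_heads * head_dim,
                    dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    # The backend is resolved per platform (XFORMERS on CUDA here); head
    # splitting and output reshaping happen inside the layer.
    out = attn(q, k, v)
    assert out.shape == (bsz, seq_len, num_heads * head_dim)
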
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 237802afccde9..be4dc3eb3c0da 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -360,13 +360,13 @@ class MultiHeadAttention(nn.Module):
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
-            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
-                           _Backend.FLEX_ATTENTION):
-                backend = _Backend.XFORMERS
-
             self.attn_backend = backend if backend in {
-                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
-            } else _Backend.TORCH_SDPA
+                _Backend.TORCH_SDPA,
+                _Backend.TORCH_SDPA_VLLM_V1,
+                _Backend.XFORMERS,
+                _Backend.PALLAS_VLLM_V1,
+                _Backend.ROCM_AITER_FA,
+            } else current_platform.get_vit_attn_backend()
 
         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
@@ -413,6 +413,19 @@ class MultiHeadAttention(nn.Module):
             from torch_xla.experimental.custom_kernel import flash_attention
             out = flash_attention(query, key, value, sm_scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.ROCM_AITER_FA:
+            from aiter import flash_attn_varlen_func
+
+            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
+            out = flash_attn_varlen_func(query,
+                                         key,
+                                         value,
+                                         softmax_scale=self.scale)
+        else:
+            # ViT attention does not support this backend yet
+            raise NotImplementedError(
+                f"ViT attention does not support the {self.attn_backend} "
+                f"backend yet.")
 
         return out.reshape(bsz, q_len, -1)
 
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index 0ca2e9e4bb688..ea5d6f29f6cfd 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -170,6 +170,7 @@ class Idefics2VisionAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
         )
+        # Use unified MultiHeadAttention with Flash Attention support
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
 
@@ -181,6 +182,8 @@ class Idefics2VisionAttention(nn.Module):
             hidden_states
         )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)
+
+        # Use unified MultiHeadAttention implementation
         out = self.attn(query_states, key_states, value_states)
         attn_output, _ = self.out_proj(out)
         return attn_output
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 58e8163e0b26e..8e9ab9649bd44 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -255,6 +255,10 @@ class InternSdpaAttention(nn.Module):
 
         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
 
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)
@@ -268,12 +272,9 @@ class InternSdpaAttention(nn.Module):
             B_, N_, H_, D_ = q.shape
             q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
             k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
 
-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, -1)
+        # Use unified MultiHeadAttention with automatic backend selection
+        x = self.attn(q, k, v)
 
         x = self.proj(x)
         return x
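The per-model changes in this patch all follow the same migration pattern: drop the manual head reshape/transpose and the direct F.scaled_dot_product_attention call, and hand the flat (batch, seq, hidden) projections to MultiHeadAttention. A condensed sketch of that pattern follows (toy module; the class and layer names are illustrative, not from the diff):

    import torch
    import torch.nn as nn
    from vllm.attention.layer import MultiHeadAttention

    class ToyViTAttention(nn.Module):

        def __init__(self, embed_dim: int, num_heads: int):
            super().__init__()
            head_dim = embed_dim // num_heads
            self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
            self.proj = nn.Linear(embed_dim, embed_dim)
            # Unified attention: no manual transpose or direct
            # F.scaled_dot_product_attention call in the model code.
            self.attn = MultiHeadAttention(num_heads, head_dim,
                                           head_dim**-0.5)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            q, k, v = self.qkv(x).chunk(3, dim=-1)  # each (B, N, embed_dim)
            return self.proj(self.attn(q, k, v))
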
diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py
index 300ed17ecaabc..eb6b685d03dc5 100644
--- a/vllm/model_executor/models/interns1_vit.py
+++ b/vllm/model_executor/models/interns1_vit.py
@@ -12,10 +12,10 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 from transformers.utils import torch_int
 
+from vllm.attention.layer import MultiHeadAttention
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -206,6 +206,10 @@ class InternSdpaAttention(nn.Module):
 
         self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
 
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
 
@@ -213,20 +217,13 @@ class InternSdpaAttention(nn.Module):
         k = self.k_proj(x)
         v = self.v_proj(x)
 
-        q = q.view(B, N, self.num_heads, self.head_dim)
-        k = k.view(B, N, self.num_heads, self.head_dim)
-        v = v.view(B, N, self.num_heads, self.head_dim)
-
         if self.qk_normalization:
             B_, N_, H_, D_ = q.shape
             q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
             k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
 
-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, -1)
+        # Use unified MultiHeadAttention with automatic backend selection
+        x = self.attn(q, k, v)
 
         x = self.projection_layer(x)
         return x
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 68aa16f8b9ecf..048894085b360 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -35,6 +35,7 @@ from transformers.models.mllama.processing_mllama import (
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.layer import MultiHeadAttention
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
@@ -517,6 +518,10 @@ class MllamaVisionSdpaAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )
 
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
+                                       1.0 / math.sqrt(self.head_dim))
+
     def forward(
         self,
         hidden_state: torch.Tensor,
@@ -524,21 +529,10 @@ class MllamaVisionSdpaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_state)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
 
-        # TODO: remove padding in image encoder
-        attn_output = F.scaled_dot_product_attention(q,
-                                                     k,
-                                                     v,
-                                                     attn_mask=attention_mask,
-                                                     dropout_p=0.0)
+        # Use unified MultiHeadAttention with automatic backend selection
+        attn_output = self.attn(q, k, v)
 
-        attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(attn_output.shape[0],
                                           attn_output.shape[1], -1)
         output, _ = self.o_proj(attn_output)
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index 17299b64978e3..2ba5f94ea3b88 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -16,6 +16,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BatchFeature, PretrainedConfig, TensorType
 
+from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -682,9 +683,9 @@ class Step3VisionAttention(nn.Module):
                                           prefix=f"{prefix}.out_proj",
                                           disable_tp=use_data_parallel)
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads,
-                           self.head_dim).transpose(1, 2).contiguous()
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
 
     def forward(
         self,
@@ -696,19 +697,9 @@ class Step3VisionAttention(nn.Module):
         # get query proj
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
-        k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
-        v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-        attn_output = F.scaled_dot_product_attention(q,
-                                                     k,
-                                                     v,
-                                                     scale=self.scale,
-                                                     is_causal=False)
-        attn_output = attn_output.transpose(1, 2).reshape(
-            bsz, tgt_len, self.num_heads * self.head_dim)
+
+        # Use unified MultiHeadAttention with automatic backend selection
+        attn_output = self.attn(q, k, v)
 
         attn_output, _ = self.out_proj(attn_output)
 
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index de30509b1ccb4..c16aa5ac608f9 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -122,4 +122,4 @@ def resolve_visual_encoder_outputs(
     uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
     if post_layer_norm is not None and uses_last_layer:
         hs_pool[-1] = post_layer_norm(encoder_outputs)
-    return torch.cat(hs_pool, dim=-1)
+    return torch.cat(hs_pool, dim=-1)
\ No newline at end of file
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index fdd3764d2c35d..0cea49eece42e 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -48,6 +48,7 @@ class _Backend(enum.Enum):
     ROCM_AITER_MLA_VLLM_V1 = enum.auto()
     ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
+    TORCH_SDPA_VLLM_V1 = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1
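Taken together with the new TORCH_SDPA_VLLM_V1 enum member above, the selection logic added in vllm/attention/layer.py reduces, on non-ROCm platforms, to the following rule. This is a condensed restatement for reference; the helper name is ours, not part of the patch:

    from vllm.attention.selector import _Backend
    from vllm.platforms import current_platform

    _VIT_BACKENDS = {
        _Backend.TORCH_SDPA,
        _Backend.TORCH_SDPA_VLLM_V1,
        _Backend.XFORMERS,
        _Backend.PALLAS_VLLM_V1,
        _Backend.ROCM_AITER_FA,
    }

    def resolve_vit_attn_backend(backend: _Backend) -> _Backend:
        # Backends outside the supported set (e.g. FLASH_ATTN, FLEX_ATTENTION)
        # now fall back to the platform's ViT default instead of being
        # coerced to XFORMERS/TORCH_SDPA.
        return (backend if backend in _VIT_BACKENDS else
                current_platform.get_vit_attn_backend())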