Feature/vit attention unification #23880 (#23978)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: baonudesifeizhai
Date: 2025-09-10 09:10:14 -04:00 (committed by GitHub)
Parent: 72d30108a0
Commit: 6cbd41909e
9 changed files with 68 additions and 56 deletions


@@ -23,6 +23,9 @@ def clear_cache():
     """Clear lru cache to ensure each test case runs without caching.
     """
     _cached_get_attn_backend.cache_clear()
+    # Clear xformers availability cache
+    import vllm.attention.layer as layer_module
+    layer_module.USE_XFORMERS_OPS = None


 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@@ -33,19 +36,28 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)

     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CpuPlatform()), \
+                patch("vllm.platforms.current_platform", CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA
+            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   RocmPlatform()), \
+                patch("vllm.platforms.current_platform", RocmPlatform()), \
+                patch("vllm.attention.layer.current_platform", RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CudaPlatform()), \
+                patch("vllm.platforms.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS

-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.attention.selector.current_platform",
+                   CudaPlatform()), \
+                patch("vllm.platforms.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
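Note on the call convention this test exercises: MultiHeadAttention takes query/key/value as flat (batch, seq_len, num_heads * head_dim) tensors and returns the same layout. A minimal usage sketch (shapes and dtype are illustrative, not from the diff):

import torch
from vllm.attention.layer import MultiHeadAttention

# 16 heads of size 64, scale=1, as constructed in the test above.
attn = MultiHeadAttention(16, 64, scale=1)

# Call sites pass flat (batch, seq, num_heads * head_dim) tensors.
q = torch.randn(2, 49, 16 * 64, dtype=torch.float16)
k = torch.randn(2, 49, 16 * 64, dtype=torch.float16)
v = torch.randn(2, 49, 16 * 64, dtype=torch.float16)

out = attn(q, k, v)  # backend is chosen per platform; output keeps the layout
assert out.shape == (2, 49, 16 * 64)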


@@ -360,13 +360,13 @@ class MultiHeadAttention(nn.Module):
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
-            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
-                           _Backend.FLEX_ATTENTION):
-                backend = _Backend.XFORMERS
-
-            self.attn_backend = backend if backend in {
-                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
-            } else _Backend.TORCH_SDPA
+            self.attn_backend = backend if backend in {
+                _Backend.TORCH_SDPA,
+                _Backend.TORCH_SDPA_VLLM_V1,
+                _Backend.XFORMERS,
+                _Backend.PALLAS_VLLM_V1,
+                _Backend.ROCM_AITER_FA,
+            } else current_platform.get_vit_attn_backend()

         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
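In effect, the constructor now keeps a requested backend only if the ViT path supports it, and otherwise asks the platform for its own ViT default instead of hard-coding TORCH_SDPA. A standalone restatement of that rule (resolve_vit_backend is a hypothetical helper; the supported set and the no-argument get_vit_attn_backend() fallback come from the diff):

from vllm.attention.selector import _Backend
from vllm.platforms import current_platform

# Backends the unified ViT MultiHeadAttention dispatches to (per the diff).
_VIT_BACKENDS = {
    _Backend.TORCH_SDPA,
    _Backend.TORCH_SDPA_VLLM_V1,
    _Backend.XFORMERS,
    _Backend.PALLAS_VLLM_V1,
    _Backend.ROCM_AITER_FA,
}

def resolve_vit_backend(requested: _Backend) -> _Backend:
    # Keep an explicitly supported backend; otherwise defer to the
    # platform's own ViT attention default.
    if requested in _VIT_BACKENDS:
        return requested
    return current_platform.get_vit_attn_backend()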
@@ -413,6 +413,19 @@ class MultiHeadAttention(nn.Module):
             from torch_xla.experimental.custom_kernel import flash_attention
             out = flash_attention(query, key, value, sm_scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.ROCM_AITER_FA:
+            from aiter import flash_attn_varlen_func
+            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
+            out = flash_attn_varlen_func(query,
+                                         key,
+                                         value,
+                                         softmax_scale=self.scale)
+        else:
+            # ViT attention hasn't supported this backend yet
+            raise NotImplementedError(
+                f"ViT attention hasn't supported {self.attn_backend} "
+                f"backend yet.")
+
         return out.reshape(bsz, q_len, -1)
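For reference, on the TORCH_SDPA paths the layer splits the flat inputs into heads, runs scaled dot-product attention, and flattens back, which is exactly the boilerplate deleted from the vision models below. A rough pure-torch equivalent (a sketch of the computation, not the layer's actual code):

import torch
import torch.nn.functional as F

def sdpa_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   num_heads: int, scale: float) -> torch.Tensor:
    # Inputs: (batch, seq, num_heads * head_dim), as at the ViT call sites.
    bsz, q_len, _ = q.shape

    def to_heads(t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, h * d) -> (batch, h, seq, d)
        return t.view(bsz, t.shape[1], num_heads, -1).transpose(1, 2)

    out = F.scaled_dot_product_attention(to_heads(q), to_heads(k),
                                         to_heads(v), scale=scale)
    # Back to (batch, seq, num_heads * head_dim), like out.reshape above.
    return out.transpose(1, 2).reshape(bsz, q_len, -1)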


@@ -170,6 +170,7 @@ class Idefics2VisionAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
         )
+        # Use unified MultiHeadAttention with Flash Attention support
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
@@ -181,6 +182,8 @@ class Idefics2VisionAttention(nn.Module):
             hidden_states
         )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)
+
+        # Use unified MultiHeadAttention implementation
         out = self.attn(query_states, key_states, value_states)
         attn_output, _ = self.out_proj(out)
         return attn_output
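The Idefics2 hunk above is the template for the remaining encoders in this commit: construct one MultiHeadAttention in __init__ and feed it the flat q/k/v straight from the QKV projection, deleting the per-module view/transpose/SDPA plumbing. Schematically (an illustrative module with placeholder dimensions, not any one file verbatim):

import torch
import torch.nn as nn
from vllm.attention.layer import MultiHeadAttention

class VisionAttention(nn.Module):
    # Illustrative: num_heads/head_dim/scale stand in for each encoder's own.
    def __init__(self, num_heads: int, head_dim: int, scale: float):
        super().__init__()
        hidden = num_heads * head_dim
        self.qkv = nn.Linear(hidden, 3 * hidden)
        # After this PR: one unified, backend-aware attention layer.
        self.attn = MultiHeadAttention(num_heads, head_dim, scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Before: view -> transpose -> F.scaled_dot_product_attention -> reshape.
        # After: the unified layer handles head splitting and dispatch.
        return self.attn(q, k, v)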


@@ -255,6 +255,10 @@ class InternSdpaAttention(nn.Module):
         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)

+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)
@@ -268,12 +272,9 @@ class InternSdpaAttention(nn.Module):
             B_, N_, H_, D_ = q.shape
             q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
             k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)

-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, -1)
+        # Use unified MultiHeadAttention with automatic backend selection
+        x = self.attn(q, k, v)

         x = self.proj(x)
         return x


@@ -12,10 +12,10 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 from transformers.utils import torch_int

+from vllm.attention.layer import MultiHeadAttention
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -206,6 +206,10 @@ class InternSdpaAttention(nn.Module):
         self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)

+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
@@ -213,20 +217,13 @@ class InternSdpaAttention(nn.Module):
         k = self.k_proj(x)
         v = self.v_proj(x)

-        q = q.view(B, N, self.num_heads, self.head_dim)
-        k = k.view(B, N, self.num_heads, self.head_dim)
-        v = v.view(B, N, self.num_heads, self.head_dim)
-
         if self.qk_normalization:
             B_, N_, H_, D_ = q.shape
             q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
             k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)

-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, -1)
+        # Use unified MultiHeadAttention with automatic backend selection
+        x = self.attn(q, k, v)

         x = self.projection_layer(x)
         return x


@@ -35,6 +35,7 @@ from transformers.models.mllama.processing_mllama import (

 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.layer import MultiHeadAttention
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
@@ -517,6 +518,10 @@ class MllamaVisionSdpaAttention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )

+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
+                                       1.0 / math.sqrt(self.head_dim))
+
     def forward(
         self,
         hidden_state: torch.Tensor,
@@ -524,21 +529,10 @@ class MllamaVisionSdpaAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_state)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-        q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
-                   self.head_dim).transpose(1, 2)
-        # TODO: remove padding in image encoder
-        attn_output = F.scaled_dot_product_attention(q,
-                                                     k,
-                                                     v,
-                                                     attn_mask=attention_mask,
-                                                     dropout_p=0.0)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
+        # Use unified MultiHeadAttention with automatic backend selection
+        attn_output = self.attn(q, k, v)
+
         attn_output = attn_output.reshape(attn_output.shape[0],
                                           attn_output.shape[1], -1)
         output, _ = self.o_proj(attn_output)


@@ -16,6 +16,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BatchFeature, PretrainedConfig, TensorType

+from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -682,9 +683,9 @@ class Step3VisionAttention(nn.Module):
             prefix=f"{prefix}.out_proj",
             disable_tp=use_data_parallel)

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads,
-                           self.head_dim).transpose(1, 2).contiguous()
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)

     def forward(
         self,
@@ -696,19 +697,9 @@ class Step3VisionAttention(nn.Module):
         # get query proj
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
-        k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
-        v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-        attn_output = F.scaled_dot_product_attention(q,
-                                                     k,
-                                                     v,
-                                                     scale=self.scale,
-                                                     is_causal=False)
-        attn_output = attn_output.transpose(1, 2).reshape(
-            bsz, tgt_len, self.num_heads * self.head_dim)
+
+        # Use unified MultiHeadAttention with automatic backend selection
+        attn_output = self.attn(q, k, v)

         attn_output, _ = self.out_proj(attn_output)


@@ -48,6 +48,7 @@ class _Backend(enum.Enum):
     ROCM_AITER_MLA_VLLM_V1 = enum.auto()
     ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
+    TORCH_SDPA_VLLM_V1 = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1