Feature/vit attention unification #23880 (#23978)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
baonudesifeizhai 2025-09-10 09:10:14 -04:00 committed by GitHub
parent 72d30108a0
commit 6cbd41909e
9 changed files with 68 additions and 56 deletions

View File

@ -23,6 +23,9 @@ def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
+ # Clear xformers availability cache
+ import vllm.attention.layer as layer_module
+ layer_module.USE_XFORMERS_OPS = None
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@ -33,19 +36,28 @@ def test_mha_attn_platform(device: str):
torch.set_default_dtype(torch.float16)
if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
with patch("vllm.attention.selector.current_platform",
CpuPlatform()), \
patch("vllm.platforms.current_platform", CpuPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
with patch("vllm.attention.selector.current_platform",
RocmPlatform()), \
patch("vllm.platforms.current_platform", RocmPlatform()), \
patch("vllm.attention.layer.current_platform", RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
with patch("vllm.attention.selector.current_platform",
CudaPlatform()), \
patch("vllm.platforms.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
with patch("vllm.attention.selector.current_platform",
CudaPlatform()), \
patch("vllm.platforms.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
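
For context, the resolved backend can also be inspected outside pytest. A minimal sketch (not part of this commit), assuming a vLLM build that contains this change and reusing only names that appear in the test above:

import torch
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend

# Backend selection keys off the default dtype, hence the float16 default above.
torch.set_default_dtype(torch.float16)

# Same toy shape as the test: 16 heads with head_dim 64.
attn = MultiHeadAttention(16, 64, scale=1)
print(attn.attn_backend)  # e.g. _Backend.TORCH_SDPA_VLLM_V1 on a CPU-only host
assert isinstance(attn.attn_backend, _Backend)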

View File

@ -360,13 +360,13 @@ class MultiHeadAttention(nn.Module):
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
- if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
- _Backend.FLEX_ATTENTION):
- backend = _Backend.XFORMERS
self.attn_backend = backend if backend in {
- _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
- } else _Backend.TORCH_SDPA
+ _Backend.TORCH_SDPA,
+ _Backend.TORCH_SDPA_VLLM_V1,
+ _Backend.XFORMERS,
+ _Backend.PALLAS_VLLM_V1,
+ _Backend.ROCM_AITER_FA,
+ } else current_platform.get_vit_attn_backend()
if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
@ -413,6 +413,19 @@ class MultiHeadAttention(nn.Module):
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
+ elif self.attn_backend == _Backend.ROCM_AITER_FA:
+ from aiter import flash_attn_varlen_func
+ # ROCm Flash Attention expects (batch, seq, heads, head_dim)
+ out = flash_attn_varlen_func(query,
+ key,
+ value,
+ softmax_scale=self.scale)
+ else:
+ # ViT attention hasn't supported this backend yet
+ raise NotImplementedError(
+ f"ViT attention hasn't supported {self.attn_backend} "
+ f"backend yet.")
return out.reshape(bsz, q_len, -1)
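
The calling convention that the vision modules below rely on follows from the final out.reshape(bsz, q_len, -1): query, key and value arrive as (batch, seq_len, num_heads * head_dim), or equivalently (batch, seq_len, num_heads, head_dim), and the output comes back flattened to (batch, seq_len, num_heads * head_dim). A rough sketch of that shape contract (illustrative only; whether it executes eagerly depends on which backend resolves on your hardware):

import torch
from vllm.attention.layer import MultiHeadAttention

num_heads, head_dim = 16, 64
attn = MultiHeadAttention(num_heads, head_dim, scale=head_dim**-0.5)

# A toy ViT-style patch sequence: batch of 2, 577 tokens.
q = torch.randn(2, 577, num_heads * head_dim)
k = torch.randn(2, 577, num_heads * head_dim)
v = torch.randn(2, 577, num_heads * head_dim)

out = attn(q, k, v)
assert out.shape == (2, 577, num_heads * head_dim)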

View File

@ -170,6 +170,7 @@ class Idefics2VisionAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
+ # Use unified MultiHeadAttention with Flash Attention support
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
@ -181,6 +182,8 @@ class Idefics2VisionAttention(nn.Module):
hidden_states
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
+ # Use unified MultiHeadAttention implementation
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)
return attn_output
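
The wiring this commit applies is the same across the vision encoders touched below: a fused QKV projection, a chunk/split into q, k and v, the unified MultiHeadAttention, then the output projection. An illustrative toy module showing that pattern (plain nn.Linear stands in for vLLM's tensor-parallel linear layers, and the sizes are made up):

import torch
import torch.nn as nn

from vllm.attention.layer import MultiHeadAttention


class ToyVisionAttention(nn.Module):

    def __init__(self, hidden: int = 1024, num_heads: int = 16):
        super().__init__()
        self.head_dim = hidden // num_heads
        self.qkv_proj = nn.Linear(hidden, 3 * hidden)
        self.out_proj = nn.Linear(hidden, hidden)
        # The backend (SDPA, xformers, flash, pallas, ROCm AITER FA) is picked
        # inside MultiHeadAttention, so the module carries no per-backend code.
        self.attn = MultiHeadAttention(num_heads, self.head_dim,
                                       self.head_dim**-0.5)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        out = self.attn(q, k, v)
        return self.out_proj(out)


patches = torch.randn(1, 197, 1024)
out = ToyVisionAttention()(patches)  # -> (1, 197, 1024)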

View File

@ -255,6 +255,10 @@ class InternSdpaAttention(nn.Module):
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
+ # Use unified MultiHeadAttention with automatic backend selection
+ self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+ self.scale)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
@ -268,12 +272,9 @@ class InternSdpaAttention(nn.Module):
B_, N_, H_, D_ = q.shape
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
- q = q.transpose(1, 2)
- k = k.transpose(1, 2)
- v = v.transpose(1, 2)
- x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
- x = x.transpose(1, 2).reshape(B, N, -1)
+ # Use unified MultiHeadAttention with automatic backend selection
+ x = self.attn(q, k, v)
x = self.proj(x)
return x

View File

@ -12,10 +12,10 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from transformers.utils import torch_int
+ from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -206,6 +206,10 @@ class InternSdpaAttention(nn.Module):
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
+ # Use unified MultiHeadAttention with automatic backend selection
+ self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+ self.scale)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
@ -213,20 +217,13 @@ class InternSdpaAttention(nn.Module):
k = self.k_proj(x)
v = self.v_proj(x)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
- q = q.transpose(1, 2)
- k = k.transpose(1, 2)
- v = v.transpose(1, 2)
- x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
- x = x.transpose(1, 2).reshape(B, N, -1)
+ # Use unified MultiHeadAttention with automatic backend selection
+ x = self.attn(q, k, v)
x = self.projection_layer(x)
return x
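
Note that only the attention math moved into MultiHeadAttention; the QK-normalization stays in the model and is applied over the flattened head dimension right before the call, as in the two hunks above. A small sketch of that step (shapes are illustrative; RMSNorm is vLLM's layer, imported at the top of this file):

import torch
from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.layernorm import RMSNorm

B, N, H, D = 2, 1025, 16, 64
q, k, v = (torch.randn(B, N, H, D) for _ in range(3))

q_norm = RMSNorm(H * D, eps=1e-6)
k_norm = RMSNorm(H * D, eps=1e-6)

# Normalize q/k across all heads at once, then restore the per-head layout.
q = q_norm(q.flatten(-2, -1)).view(B, N, H, D)
k = k_norm(k.flatten(-2, -1)).view(B, N, H, D)

attn = MultiHeadAttention(H, D, scale=D**-0.5)
out = attn(q, k, v)  # the (B, N, H, D) layout is accepted as well
assert out.shape == (B, N, H * D)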

View File

@ -35,6 +35,7 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
+ from vllm.attention.layer import MultiHeadAttention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
@ -517,6 +518,10 @@ class MllamaVisionSdpaAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
+ # Use unified MultiHeadAttention with automatic backend selection
+ self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
+ 1.0 / math.sqrt(self.head_dim))
def forward(
self,
hidden_state: torch.Tensor,
@ -524,21 +529,10 @@ class MllamaVisionSdpaAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_state)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
- self.head_dim).transpose(1, 2)
- k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
- self.head_dim).transpose(1, 2)
- v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
- self.head_dim).transpose(1, 2)
- # TODO: remove padding in image encoder
- attn_output = F.scaled_dot_product_attention(q,
- k,
- v,
- attn_mask=attention_mask,
- dropout_p=0.0)
+ # Use unified MultiHeadAttention with automatic backend selection
+ attn_output = self.attn(q, k, v)
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(attn_output.shape[0],
- attn_output.shape[1], -1)
output, _ = self.o_proj(attn_output)

View File

@ -16,6 +16,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType
+ from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@ -682,9 +683,9 @@ class Step3VisionAttention(nn.Module):
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads,
- self.head_dim).transpose(1, 2).contiguous()
+ # Use unified MultiHeadAttention with automatic backend selection
+ self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+ self.scale)
def forward(
self,
@ -696,19 +697,9 @@ class Step3VisionAttention(nn.Module):
# get query proj
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
- q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
- k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
- v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
- q = q.transpose(1, 2)
- k = k.transpose(1, 2)
- v = v.transpose(1, 2)
- attn_output = F.scaled_dot_product_attention(q,
- k,
- v,
- scale=self.scale,
- is_causal=False)
- attn_output = attn_output.transpose(1, 2).reshape(
- bsz, tgt_len, self.num_heads * self.head_dim)
+ # Use unified MultiHeadAttention with automatic backend selection
+ attn_output = self.attn(q, k, v)
attn_output, _ = self.out_proj(attn_output)

View File

@ -122,4 +122,4 @@ def resolve_visual_encoder_outputs(
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
- return torch.cat(hs_pool, dim=-1)
+ return torch.cat(hs_pool, dim=-1)

View File

@ -48,6 +48,7 @@ class _Backend(enum.Enum):
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
+ TORCH_SDPA_VLLM_V1 = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_VLLM_V1 = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
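
To spell out how the new enum member ties into the MultiHeadAttention diff above, here is an illustrative paraphrase of the selection rule (resolve_vit_backend is a hypothetical helper, not part of vLLM's API): keep the globally selected backend when the ViT path can serve it, otherwise defer to the platform's ViT default.

from vllm.attention.selector import _Backend

_VIT_CAPABLE = {
    _Backend.TORCH_SDPA,
    _Backend.TORCH_SDPA_VLLM_V1,
    _Backend.XFORMERS,
    _Backend.PALLAS_VLLM_V1,
    _Backend.ROCM_AITER_FA,
}


def resolve_vit_backend(backend: _Backend, platform) -> _Backend:
    """Mirror of the expression in MultiHeadAttention.__init__ (illustrative)."""
    return (backend
            if backend in _VIT_CAPABLE else platform.get_vit_attn_backend())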